This is not meant as a comprehensive guide, just a place to start and get things up and running. Read further below for background and an explanation.

1. Create a custom callback
# Tracking Class
from fastai.callback.core import Callback
from mlflow.tracking import MlflowClient
from typing import List

class MLFlowTracking(Callback):
    "A `Callback` that logs the loss and other metrics to MLflow after each epoch"

    def __init__(self, metric_names:List[str], client:MlflowClient, run_id:str):
        self.client = client
        self.run_id = run_id
        self.metric_names = metric_names

    def after_epoch(self):
        "Log the latest recorded value of each tracked metric to MLflow"
        for metric_name in self.metric_names:
            # recorder.metric_names starts with 'epoch', which the rows in
            # recorder.values omit, hence the [1:] slice
            m_idx = list(self.recorder.metric_names[1:]).index(metric_name)
            if len(self.recorder.values) > 0:
                val = self.recorder.values[-1][m_idx]
                self.client.log_metric(self.run_id, metric_name, float(val))
Custom Callback Class
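The index arithmetic in `after_epoch` is easy to sanity-check outside a training loop. The lists below are illustrative stand-ins for what fastai's `Recorder` exposes: `metric_names` starts with `'epoch'` (hence the `[1:]` slice) and ends with `'time'`, while each row of `values` holds only the numeric results for one epoch.

```python
# Stand-ins for fastai's recorder attributes (illustrative values)
metric_names = ['epoch', 'train_loss', 'valid_loss', 'error_rate', 'time']
values = [
    [0.95, 0.80, 0.21],   # epoch 0: train_loss, valid_loss, error_rate
    [0.52, 0.41, 0.13],   # epoch 1
]

# The same lookup the callback performs
m_idx = list(metric_names[1:]).index('error_rate')
val = values[-1][m_idx]

print(m_idx, val)  # -> 2 0.13
```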

2. Initialize the MLflow Experiment, Client and Run

# Make sure to install mlflow using pip, conda or your favourite package manager

import mlflow
from mlflow.tracking import MlflowClient

TRACKING_URL = "<your MLflow tracking server URL>"
EXPERIMENT_NAME = "example_experiment"

# Create the Client
mlfclient = MlflowClient(tracking_uri=TRACKING_URL)

# Check if the experiment already exists, or create it
mlfexp = mlfclient.get_experiment_by_name(EXPERIMENT_NAME)
if mlfexp is None:
    mlfexp_id = mlfclient.create_experiment(EXPERIMENT_NAME)
    mlfexp = mlfclient.get_experiment(mlfexp_id)

# Create a run under the experiment
mlrun = mlfclient.create_run(experiment_id=mlfexp.experiment_id)
Initialize the MLflow Runner and Client

3. Log the experiment's hyperparameters

# Set up a params dictionary
params = {
    'crop_size': 250,
    'epochs': 20,
}

# Log each params item to MLflow (for this experiment's run)
for k, v in params.items():
    mlfclient.log_param(mlrun.info.run_id, key=k, value=v)

# Make sure to use params in your data loaders and experiments
Hyperparameter logging
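`log_param` takes flat key/value pairs, so nested configuration needs flattening first. A small hypothetical helper (`flatten_params` is not part of MLflow, just a sketch) that turns a nested dict into dotted keys ready for the logging loop above:

```python
def flatten_params(d, prefix=''):
    """Flatten a nested dict into {'outer.inner': value} pairs."""
    flat = {}
    for k, v in d.items():
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            flat.update(flatten_params(v, prefix=f"{key}."))
        else:
            flat[key] = v
    return flat

params = {'crop_size': 250, 'epochs': 20, 'optim': {'lr': 1e-3, 'wd': 0.01}}
flat = flatten_params(params)
print(flat)  # -> {'crop_size': 250, 'epochs': 20, 'optim.lr': 0.001, 'optim.wd': 0.01}
# Each item can then be logged as before:
# for k, v in flat.items():
#     mlfclient.log_param(mlrun.info.run_id, key=k, value=v)
```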

4. Add the custom callback to the learner

# Assuming the DataLoaders (dls) are already set up
learn = cnn_learner(dls, resnet18, metrics=error_rate,
                    cbs=[MLFlowTracking(metric_names=['valid_loss', 'train_loss', 'error_rate'],
                                        client=mlfclient, run_id=mlrun.info.run_id)])

# Training logs the metrics and losses in your environment as well as to MLflow
learn.fine_tune(params['epochs'])
Custom Callback and Metrics

FastAI is a powerful library with a set of good features that make ML fun. One of the things I was struggling with was losing track of my experiments' configurations and the metrics and losses they generated.

MLflow seemed like a good fit, at least for what I was trying to accomplish. Most documentation around setting up MLflow with fastai is stale (covering fastai v1). MLflow does have an autolog feature but, per the documentation (and my own trials), it appears to work only with fastai 1.0.61 or earlier.

The above setup, although very crude, works for the simple setup I have. It assumes access to an MLflow server with an accessible URI (TRACKING_URL). One caveat is that metric values tend to be empty for the first couple of epochs; with more understanding of fastai's Callbacks and Recorder, I might be able to track down why.
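Until that caveat is tracked down, one defensive option is to skip values that are not yet populated before calling `log_metric`. A sketch of the guard (the sample values are illustrative):

```python
# Metric values as they might come out of the recorder in an early epoch
latest = {'train_loss': 0.95, 'valid_loss': None, 'error_rate': None}

# Only keep metrics that already have a numeric value
loggable = {k: float(v) for k, v in latest.items() if v is not None}
print(loggable)  # -> {'train_loss': 0.95}

# These can then be logged safely:
# for k, v in loggable.items():
#     mlfclient.log_metric(mlrun.info.run_id, k, v)
```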

Happy Deep Learning