
Workflow Orchestration for Machine Learning


Create, schedule, and monitor workflows by building scalable pipelines.
Goku Mohandas



Intuition

So far we've implemented our DataOps (ETL, preprocessing, validation, etc.) and MLOps (optimization, training, evaluation, etc.) workflows as Python function calls. This has worked well since our dataset is static and small. But what happens when we need to:

  • schedule these workflows as new data arrives?
  • scale these workflows as our data grows?
  • share these workflows to downstream consumers?
  • monitor these workflows?

We'll need to break down our end-to-end ML pipeline into individual workflows that can be orchestrated as needed. There are several tools that can help us do this, such as Airflow, Prefect, Dagster, Luigi, and even some ML-focused options such as Metaflow, Flyte, KubeFlow Pipelines, Vertex pipelines, etc. We'll be creating our workflows using Airflow for its:

  • wide adoption of the open source platform in industry
  • Python based software development kit (SDK)
  • integration with the ecosystem (data ingestion, processing, etc.)
  • ability to run locally and scale easily
  • maturity over the years as part of the Apache ecosystem

We'll be running Airflow locally but we can easily scale it by running on a managed cluster platform where we can run Python, Hadoop, Spark, etc. on large batch processing jobs (AWS EMR, Google Cloud's Dataproc, on-prem hardware, etc.).

Airflow

Before we create our specific pipelines, let's understand and implement Airflow's overarching concepts that will allow us to "author, schedule, and monitor workflows".

Set up

To install and run Airflow, we can either do so locally or with Docker. If using docker-compose to run Airflow inside Docker containers, we'll want to allocate at least 4 GB of memory.

# Configurations
export AIRFLOW_HOME=${PWD}/airflow
AIRFLOW_VERSION=2.3.3
PYTHON_VERSION="$(python --version | cut -d " " -f 2 | cut -d "." -f 1-2)"
CONSTRAINT_URL="https://raw.githubusercontent.com/apache/airflow/constraints-${AIRFLOW_VERSION}/constraints-${PYTHON_VERSION}.txt"

# Install Airflow (may need to upgrade pip)
pip install "apache-airflow==${AIRFLOW_VERSION}" --constraint "${CONSTRAINT_URL}"

# Initialize DB (SQLite by default)
airflow db init

This will create an airflow directory with the following components:

airflow/
├── logs/
├── airflow.cfg
├── airflow.db
├── unittests.cfg
└── webserver_config.py

We're going to edit the airflow.cfg file to best fit our needs:

# Inside airflow.cfg
enable_xcom_pickling = True  # needed for Great Expectations airflow provider
load_examples = False  # don't clutter webserver with examples

And we'll perform a reset to implement these configuration changes.

airflow db reset -y

Now we're ready to initialize our database with an admin user, which we'll use to log in and access our workflows in the webserver.

# We'll be prompted to enter a password
airflow users create \
    --username admin \
    --firstname FIRSTNAME \
    --lastname LASTNAME \
    --role Admin \
    --email EMAIL

Webserver

Once we've created a user, we're ready to launch the webserver and log in using our credentials.

# Launch webserver
export AIRFLOW_HOME=${PWD}/airflow
airflow webserver --port 8080  # http://localhost:8080

The webserver allows us to run and inspect workflows, establish connections to external data storage, manage users, etc. through a UI. Similarly, we could also use Airflow's REST API or command-line interface (CLI) to perform the same operations. However, we'll be using the webserver because it's convenient to visually inspect our workflows.

airflow webserver

We'll explore the different components of the webserver as we learn about Airflow and implement our workflows.

Scheduler

Next, we need to launch our scheduler, which will execute and monitor the tasks in our workflows. The scheduler executes tasks by reading from the metadata database and ensures each task has what it needs to finish running. We'll go ahead and execute the following commands in a separate terminal window:

# Launch scheduler (in separate terminal)
export AIRFLOW_HOME=${PWD}/airflow
export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
airflow scheduler

Executor

As our scheduler reads from the metadata database, the executor determines what worker processes are necessary for the task to run to completion. Since our default database is SQLite, which can't support multiple connections, our default executor is the Sequential Executor. However, if we choose a more production-grade database option such as PostgreSQL or MySQL, we can choose scalable executor backends such as Celery, Kubernetes, etc. For example, running Airflow with Docker uses PostgreSQL as the database and therefore uses the Celery Executor backend to run tasks in parallel.
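For reference, switching to a more scalable setup is largely a configuration change. Below is a minimal sketch of the relevant airflow.cfg entries, assuming a local PostgreSQL instance and the Celery executor (the connection string values are illustrative):

# Inside airflow.cfg (illustrative values)
executor = CeleryExecutor
sql_alchemy_conn = postgresql+psycopg2://airflow:airflow@localhost:5432/airflow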

DAGs

Workflows are defined by directed acyclic graphs (DAGs), whose nodes represent tasks and edges represent the data flow relationship between the tasks. Directed and acyclic implies that workflows can only execute in one direction and a previous, upstream task cannot run again once a downstream task has started.

basic DAG

DAGs can be defined inside Python workflow scripts inside the airflow/dags directory and they'll automatically appear (and continuously be updated) on the webserver. Before we start creating our DataOps and MLOps workflows, we'll learn about Airflow's concepts via an example DAG outlined in airflow/dags/example.py. Execute the following commands in a new (3rd) terminal window:

mkdir airflow/dags
touch airflow/dags/example.py

Inside each workflow script, we can define some default arguments that will apply to all DAGs within that workflow.

# Default DAG args
default_args = {
    "owner": "airflow",
}

There are many more default arguments and we'll cover them as we go through the concepts.

We can initialize DAGs with many parameters (which will override the same parameters in default_args) and in several different ways:

  • using a with statement

    from airflow import DAG
    from airflow.utils.dates import days_ago

    with DAG(
        dag_id="example",
        description="Example DAG",
        default_args=default_args,
        schedule_interval=None,
        start_date=days_ago(2),
        tags=["example"],
    ) as example:
        # Define tasks
        pass
    

  • using the dag decorator

    from airflow.decorators import dag
    from airflow.utils.dates import days_ago

    @dag(
        dag_id="example",
        description="Example DAG",
        default_args=default_args,
        schedule_interval=None,
        start_date=days_ago(2),
        tags=["example"],
    )
    def example():
        # Define tasks
        pass
    

There are many parameters that we can initialize our DAGs with, including a start_date and a schedule_interval. While we could have our workflows execute on a temporal cadence, many ML workflows are initiated by events, which we can map using sensors and hooks to external databases, file systems, etc.

Tasks

Tasks are the operations that are executed in a workflow and are represented by nodes in a DAG. Each task should be a clearly defined single operation and it should be idempotent, which means we can execute it multiple times and expect the same result and system state. This is important in the event we need to retry a failed task and don't have to worry about resetting the state of our system. Like DAGs, there are several different ways to implement tasks:

  • using the task decorator

    from airflow.decorators import dag, task
    from airflow.utils.dates import days_ago
    
    @dag(
        dag_id="example",
        description="Example DAG with task decorators",
        default_args=default_args,
        schedule_interval=None,
        start_date=days_ago(2),
        tags=["example"],
    )
    def example():
        @task
        def task_1():
            return 1
        @task
        def task_2(x):
            return x+1
    

  • using Operators

    from airflow.decorators import dag
    from airflow.operators.bash_operator import BashOperator
    from airflow.utils.dates import days_ago
    
    @dag(
        dag_id="example",
        description="Example DAG with Operators",
        default_args=default_args,
        schedule_interval=None,
        start_date=days_ago(2),
        tags=["example"],
    )
    def example():
        # Define tasks
        task_1 = BashOperator(task_id="task_1", bash_command="echo 1")
        task_2 = BashOperator(task_id="task_2", bash_command="echo 2")
    

Though the graphs are directed, we can establish certain trigger rules for each task to execute on conditional successes or failures of the parent tasks.
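For example, a task can be configured to run regardless of whether its parents succeeded or failed via its trigger_rule parameter. A minimal sketch (the notify task here is hypothetical):

from airflow.operators.bash_operator import BashOperator
from airflow.utils.trigger_rule import TriggerRule

# Runs once all upstream tasks have finished, regardless of their status
notify = BashOperator(
    task_id="notify",
    bash_command="echo notify regardless of upstream outcome",
    trigger_rule=TriggerRule.ALL_DONE,
)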

Operators

The second method of creating tasks involved using Operators, which define what exactly the task will be doing. Airflow has many built-in Operators such as the BashOperator or PythonOperator, which allow us to execute bash and Python commands respectively.

# BashOperator
from airflow.operators.bash_operator import BashOperator
task_1 = BashOperator(task_id="task_1", bash_command="echo 1")

# PythonOperator
from airflow.operators.python import PythonOperator
task_2 = PythonOperator(
    task_id="task_2",
    python_callable=foo,
    op_kwargs={"arg1": ...})

There are also many other Airflow native Operators (email, S3, MySQL, Hive, etc.), as well as community maintained provider packages (Kubernetes, Snowflake, Azure, AWS, Salesforce, Tableau, etc.), to execute tasks specific to certain platforms or tools.

We can also create our own custom Operators by extending the BaseOperator class.
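As a minimal sketch (the EchoOperator below is a hypothetical example, not part of our project), a custom Operator subclasses BaseOperator and implements an execute method:

from airflow.models.baseoperator import BaseOperator

class EchoOperator(BaseOperator):
    """Hypothetical custom Operator that prints and returns a message."""
    def __init__(self, message: str, **kwargs):
        super().__init__(**kwargs)
        self.message = message

    def execute(self, context):
        print(self.message)
        return self.message  # returned values are pushed to XCom by default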

Relationships

Once we've defined our tasks using Operators or as decorated functions, we need to define the relationships between them (edges). The way we define the relationships depends on how our tasks were defined:

  • using decorated functions

    # Task relationships
    x = task_1()
    y = task_2(x=x)
    

  • using Operators

    # Task relationships
    task_1 >> task_2  # same as task_1.set_downstream(task_2) or
                      # task_2.set_upstream(task_1)
    

In both scenarios, we're setting task_2 as the downstream task to task_1.

Note

We can even create intricate DAGs by using these notations to define the relationships.

task_1 >> [task_2_1, task_2_2] >> task_3
task_2_2 >> task_4
[task_3, task_4] >> task_5
DAG

XComs

When we use task decorators, we can see how values can be passed between tasks. But how can we pass values when using Operators? Airflow uses XCom (cross communication) objects, defined with a key, value, timestamp and task_id, to push and pull values between tasks. When we use decorated functions, XComs are being used under the hood but it's abstracted away, allowing us to pass values amongst Python functions seamlessly. But when using Operators, we'll need to explicitly push and pull the values as needed.

from airflow.decorators import dag
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago

def _task_1(ti):
    x = 2
    ti.xcom_push(key="x", value=x)

def _task_2(ti):
    x = ti.xcom_pull(key="x", task_ids=["task_1"])[0]
    y = x + 3
    ti.xcom_push(key="y", value=y)

@dag(
    dag_id="example2",
    description="Example DAG",
    default_args=default_args,
    schedule_interval=None,
    start_date=days_ago(2),
    tags=["example"],
)
def example2():
    # Tasks
    task_1 = PythonOperator(task_id="task_1", python_callable=_task_1)
    task_2 = PythonOperator(task_id="task_2", python_callable=_task_2)
    task_1 >> task_2

We can also view our XComs on the webserver by going to Admin >> XComs:

xcoms

Warning

The data we pass between tasks should be small (metadata, metrics, etc.) because Airflow's metadata database is not equipped to hold large artifacts. However, if we do need to store and use the large results of our tasks, it's best to use external data storage (blob storage, model registry, etc.) and perform heavy processing using Spark or inside data systems like a data warehouse.

DAG runs

Once we've defined the tasks and their relationships, we're ready to run our DAGs. We'll start defining our DAG like so:

# Run DAGs
example1_dag = example()
example2_dag = example2()

If we refresh our webserver page (http://localhost:8080/), our new DAGs will appear.

Manual

Our DAG is initially paused since we specified dags_are_paused_at_creation = True inside our airflow.cfg configuration, so we'll have to manually execute this DAG by clicking on it > unpausing it (toggle) > triggering it (button). To view the logs for any of the tasks in our DAG run, we can click on the task > Log.

triggering a DAG

Note

We could also use Airflow's REST API (with configured authorization) or command-line interface (CLI) to inspect and trigger workflows (and a whole lot more). Or we could even use the TriggerDagRunOperator to trigger DAGs from within another workflow.

# CLI to run dags
airflow dags trigger <DAG_ID>
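A minimal sketch of the latter, assuming we want one workflow to kick off the DataOps DAG defined later in this lesson:

from airflow.operators.trigger_dagrun import TriggerDagRunOperator

# Task that triggers another DAG by its dag_id
trigger_dataops = TriggerDagRunOperator(
    task_id="trigger_dataops",
    trigger_dag_id="DataOps",
)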

Interval

Had we specified a start_date and schedule_interval when defining the DAG, it would have automatically executed at the appropriate times. For example, the DAG below will have started two days ago and will be triggered at the start of every day.

from airflow.decorators import dag
from airflow.utils.dates import days_ago
from datetime import timedelta

@dag(
    dag_id="example",
    default_args=default_args,
    schedule_interval=timedelta(days=1),
    start_date=days_ago(2),
    tags=["example"],
    catchup=False,
)
def example():
    # Define tasks
    pass

Warning

Depending on the start_date and schedule_interval, our workflow should have been triggered several times and Airflow will try to catchup to the current time. We can avoid this by setting catchup=False when defining the DAG. We can also set this configuration as part of the default arguments:

default_args = {
    "owner": "airflow",
    "catch_up": False,
}

However, if we did want to run particular runs in the past, we can manually backfill what we need.

We could also specify a cron expression for our schedule_interval parameter or even use cron presets.
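For instance, the daily cadence above could be written as a cron expression or an equivalent preset; a sketch reusing our example DAG:

from airflow.decorators import dag
from airflow.utils.dates import days_ago

@dag(
    dag_id="example",
    default_args=default_args,
    schedule_interval="0 0 * * *",  # cron expression: every day at midnight (the "@daily" preset is equivalent)
    start_date=days_ago(2),
    tags=["example"],
    catchup=False,
)
def example():
    pass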

Airflow's Scheduler will run our workflows one schedule_interval from the start_date. For example, if we want our workflow to start on 01-01-1983 and run @daily, then the first run will be immediately after 01-01-1983T11:59.

Sensors

While it may make sense to execute many data processing workflows on a scheduled interval, machine learning workflows may require more nuanced triggers. We shouldn't be wasting compute by executing our workflows just in case we have new data. Instead, we can use sensors to trigger workflows when some external condition is met. For example, we can initiate data processing when a new batch of annotated data appears in a database, when a specific file appears in a file system, etc.
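A minimal sketch using Airflow's FileSensor, assuming a filesystem connection named fs_default has been created and that we're waiting on a hypothetical projects.json file:

from airflow.sensors.filesystem import FileSensor

# Polls until the file appears before downstream tasks run
wait_for_data = FileSensor(
    task_id="wait_for_data",
    fs_conn_id="fs_default",   # filesystem connection defined via the webserver (Admin > Connections)
    filepath="projects.json",  # path relative to the connection's base path
    poke_interval=60,          # check every 60 seconds
)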

There's so much more to Airflow (monitoring, task groups, smart sensors, etc.) so be sure to explore these features as you need them by using the official documentation.

DataOps

Now that we've reviewed Airflow's major concepts, we're ready to create the DataOps workflow for our application. It involves a series of tasks around extraction, transformation, loading, validation, etc. We're going to use a simplified data stack (local file, validation, etc.) as opposed to a production data stack but the overall workflows are similar. Instead of extracting data from a source, validating and transforming it and then loading into a data warehouse, we're going to perform ETL from a local file and load the processed data into another local file.

ETL pipelines in production

Note

We'll be breaking apart our etl_data() function from our tagifai/main.py script so that we can show what the proper data validation tasks look like in production workflows.

touch airflow/dags/workflows.py
from airflow.decorators import dag, task
from airflow.utils.dates import days_ago

# Default DAG args
default_args = {
    "owner": "airflow",
    "catch_up": False,
}

# Define DAG
@dag(
    dag_id="DataOps",
    description="DataOps tasks.",
    default_args=default_args,
    schedule_interval=None,
    start_date=days_ago(2),
    tags=["dataops"],
)
def dataops():
    pass

ETL vs. ELT

If using a data warehouse (ex. Snowflake), it's common to see ELT (extract-load-transform) workflows to have a permanent location for all historical data. Learn more about the data stack and the different workflow options here.

# Define DAG
(
    extract
    >> [validate_projects, validate_tags]
    >> load
    >> transform
    >> validate_transforms
)
dataops workflow

Extraction

To keep things simple, we'll continue to keep our data as a local file but in a real production setting, our data can come from a wide variety of data systems.

Note

Ideally, the data labeling workflows would have occurred prior to the DataOps workflows. Depending on the task, it may involve natural labels, where the event that occurred is the label. Or there may be explicit manual labeling workflows that need to be inspected and approved.

def _extract(ti):
    """Extract from source (ex. DB, API, etc.)
    Our simple ex: extract data from a URL
    """
    projects = utils.load_json_from_url(url=config.PROJECTS_URL)
    tags = utils.load_json_from_url(url=config.TAGS_URL)
    ti.xcom_push(key="projects", value=projects)
    ti.xcom_push(key="tags", value=tags)

@dag(...)
def dataops():
    extract = PythonOperator(task_id="extract", python_callable=_extract)

Warning

XComs should be used to share small metadata objects and not large data assets like this. But we're doing so here only to simulate a pipeline where these assets would be passed between tasks prior to validation and loading.

Typically we'll use sensors to trigger workflows when a condition is met or trigger them directly from the external source via API calls, etc. Our workflows can communicate with the different platforms by establishing a connection and then using hooks to interface with the database, data warehouse, etc.
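For example, here is a hedged sketch of querying a warehouse table via a hook, assuming the postgres provider package is installed and a connection named warehouse has been created in the webserver (the projects table is hypothetical):

from airflow.providers.postgres.hooks.postgres import PostgresHook

def _count_projects():
    """Use an Airflow connection + hook to query a warehouse table."""
    hook = PostgresHook(postgres_conn_id="warehouse")  # connection created in Admin > Connections
    return hook.get_first("SELECT COUNT(*) FROM projects")[0]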

Validation

The specific process of where and how we extract our data can be bespoke, but what's important is that we have continuous integration to execute our workflows. A key aspect to trusting this continuous integration is validation at every step of the way. We'll once again use Great Expectations, as we did in our testing lesson, to validate our incoming data before transforming it.

With the Airflow concepts we've learned so far, there are many ways to use our data validation library to validate our data. Regardless of what data validation tool we use (ex. Great Expectations, TFX, AWS Deequ, etc.), we could use the BashOperator, PythonOperator, etc. to run our tests. However, Great Expectations has an Airflow provider package to make it even easier to validate our data. This package contains a GreatExpectationsOperator which we can use to execute specific checkpoints as tasks.

Recall from our testing lesson that we used the following CLI commands to perform our data validation tests:

great_expectations checkpoint run projects
great_expectations checkpoint run tags

We can perform the same operations as Airflow tasks within our DataOps workflow, either with the BashOperator:

from airflow.operators.bash_operator import BashOperator
validate_projects = BashOperator(task_id="validate_projects", bash_command="great_expectations checkpoint run projects")
validate_tags = BashOperator(task_id="validate_tags", bash_command="great_expectations checkpoint run tags")

or by generating a Python script for each checkpoint:

great_expectations checkpoint script <CHECKPOINT_NAME>

This will generate a Python script under great_expectations/uncommitted/run_<CHECKPOINT_NAME>.py, which we can wrap in a function and call using a PythonOperator.

However, we'll use the GreatExpectationsOperator from the provider package, which lets us execute our checkpoints directly as tasks:

pip install airflow-provider-great-expectations==0.1.1
from great_expectations_provider.operators.great_expectations import GreatExpectationsOperator

# Validate data
validate_projects = GreatExpectationsOperator(
    task_id="validate_projects",
    checkpoint_name="projects",
    data_context_root_dir="tests/great_expectations",
    fail_task_on_validation_failure=True,
)
validate_tags = GreatExpectationsOperator(
    task_id="validate_tags",
    checkpoint_name="tags",
    data_context_root_dir="tests/great_expectations",
    fail_task_on_validation_failure=True,
)

We want both tasks to pass, so we set the fail_task_on_validation_failure parameter to True so that downstream tasks don't execute if either fails.

Note

Reminder that we previously set the following configuration in our airflow.cfg file since the output of the GreatExpectationsOperator is not JSON serializable.

# Inside airflow.cfg
enable_xcom_pickling = True

Load

Once we've validated our data, we're ready to load it into our data system (ex. data warehouse). This will be the primary system that potential downstream applications will depend on for current and future versions of data.

def _load(ti):
    """Load into data system (ex. warehouse)
    Our simple ex: load extracted data into a local file
    """
    projects = ti.xcom_pull(key="projects", task_ids=["extract"])[0]
    tags = ti.xcom_pull(key="tags", task_ids=["extract"])[0]
    utils.save_dict(d=projects, filepath=Path(config.DATA_DIR, "projects.json"))
    utils.save_dict(d=tags, filepath=Path(config.DATA_DIR, "tags.json"))

@dag(...)
def dataops():
    ...
    load = PythonOperator(task_id="load", python_callable=_load)

Transform

Once we have validated and loaded our data, we're ready to transform it. Our DataOps workflows are not specific to any particular downstream consumer, so the transformation must be globally relevant (ex. cleaning missing data, aggregation, etc.). We have a wide variety of Operators to choose from depending on the tools we're using for compute (ex. Python, Spark, DBT, etc.). Many of these options have the advantage of directly performing the transformations in our data warehouse.

def _transform(ti):
    """Transform (ex. using DBT inside DWH)
    Our simple ex: using pandas to remove missing data samples
    """
    projects = ti.xcom_pull(key="projects", task_ids=["extract"])[0]
    df = pd.DataFrame(projects)
    df = df[df.tag.notnull()]  # drop rows w/ no tag
    utils.save_dict(d=df.to_dict(orient="records"), filepath=Path(config.DATA_DIR, "projects.json"))

@dag(...)
def dataops():
    ...
    transform = PythonOperator(task_id="transform", python_callable=_transform)
    validate_transforms = GreatExpectationsOperator(
        task_id="validate_transforms",
        checkpoint_name="projects",
        data_context_root_dir="tests/great_expectations",
        fail_task_on_validation_failure=True,
    )

MLOps

Once we have our data prepared, we're ready to create one of the many downstream applications that will consume it. We'll set up our MLOps pipeline inside our airflow/dags/workflows.py script:

# airflow/dags/workflows.py
@dag(
    dag_id="MLOps",
    description="MLOps tasks.",
    default_args=default_args,
    schedule_interval=None,
    start_date=days_ago(2),
    tags=["mlops"],
)
def mlops():
    pass
# Define DAG
(
    prepare
    >> validate_prepared_data
    >> optimize
    >> train
    >> offline_evaluation
    >> online_evaluation
    >> [deploy, inspect]
)
mlops workflow

Prepare

Before kicking off any experimentation, there may be task-specific data preparation that needs to happen. This is different from the data transformation during the DataOps workflow because these transformations are specific to this downstream task rather than globally relevant. In our case, we incorporate additional labeling constraints to simplify our classification task:

prepare = PythonOperator(
    task_id="prepare",
    python_callable=main.label_data,
    op_kwargs={"args_fp": Path(config.CONFIG_DIR, "args.json")},
)

And similarly to some of the DataOps tasks, we'll validate any changes we applied to our data:

validate_prepared_data = GreatExpectationsOperator(
    task_id="validate_prepared_data",
    checkpoint_name="labeled_projects",
    data_context_root_dir="tests/great_expectations",
    fail_task_on_validation_failure=True,
)

Training

Once we have our data prepped, we can use them to optimize and train the best models. Since these tasks can require lots of compute, we would typically run this entire pipeline in a managed cluster platform which can scale up as our data and models grow.

# Optimization
optimize = PythonOperator(
    task_id="optimize",
    python_callable=main.optimize,
    op_kwargs={
        "args_fp": Path(config.CONFIG_DIR, "args.json"),
        "study_name": "optimization",
        "num_trials": 1,
    },
)
# Training
train = PythonOperator(
    task_id="train",
    python_callable=main.train_model,
    op_kwargs={
        "args_fp": Path(config.CONFIG_DIR, "args.json"),
        "experiment_name": "baselines",
        "run_name": "sgd",
    },
)

Offline evaluation

It's imperative that we evaluate our trained models so that we can trust them. We've extensively covered offline evaluation before, so here we'll talk about how the evaluation is used. It won't always be a simple decision where all metrics/slices are performing better than the previous version. In these scenarios, it's important to know what our main priorities are and where we can have some leeway:

  • What criteria are most important?
  • What criteria can/cannot regress?
  • How much of a regression can be tolerated?
assert precision > prev_precision  # most important, cannot regress
assert recall >= best_prev_recall - 0.03  # recall cannot regress > 3%
assert metrics["class"]["nlp"]["f1"] > prev_nlp_f1  # priority class
assert metrics["slices"]["class"]["nlp_cnn"]["f1"] > prev_nlp_cnn_f1  # priority slice

And of course, there are some components, such as behavioral tests of our models, that should always pass. We can incorporate this business logic into a function and determine if a newly trained version of the system is better than the current version.

def _offline_evaluation():
    """Compare offline evaluation report
    (overall, fine-grained and slice metrics).
    And ensure model behavioral tests pass.
    """
    return True
offline_evaluation = PythonOperator(
    task_id="offline_evaluation",
    python_callable=_offline_evaluation,
)

Online evaluation

Once our system has passed offline evaluation criteria, we're ready to evaluate it in the online setting. Here we are using the BranchPythonOperator to execute different actions based on the results of online evaluation.

from airflow.operators.python import BranchPythonOperator

This Operator will execute a function whose return value will be a single task id (or a list of task ids).

def _online_evaluation():
    """Run online experiments (AB, shadow, canary) to
    determine if new system should replace the current.
    """
    passed = True
    if passed:
        return "deploy"
    else:
        return "inspect"
online_evaluation = BranchPythonOperator(
    task_id="online_evaluation",
    python_callable=_online_evaluation,
)

The returned task ids can correspond to tasks that are simply used to direct the workflow towards a certain set of tasks based on upstream results. In our case, we want to deploy the improved model or inspect it if it failed online evaluation requirements.
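For instance, the inspect branch could simply be a placeholder task that a human follows up on; a minimal sketch, assuming Airflow 2.3+ where EmptyOperator is available (DummyOperator in older versions):

from airflow.operators.empty import EmptyOperator

# No-op task that the branch routes to when online evaluation fails
inspect = EmptyOperator(task_id="inspect")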

Deploy

If our model passed our evaluation criteria, then we can deploy and serve it. Again, there are many different options here, such as using our CI/CD Git workflows to deploy the model wrapped as a scalable microservice or, for more streamlined deployments, using a purpose-built model server to seamlessly inspect, update, serve, roll back, etc. multiple versions of models.

deploy = BashOperator(
    task_id="deploy",
    bash_command="echo update model endpoint w/ new artifacts",
)

Continual learning

The DataOps and MLOps workflows connect to create an ML system that's capable of continually learning. Such a system will guide us with when to update, what exactly to update and how to update it (easily).

We use the word continual (repeat with breaks) instead of continuous (repeat without interruption / intervention) because we're not trying to create a system that will automatically update with new incoming data without human intervention.

Monitoring

Our production system is live and monitored. When an event of interest occurs (ex. drift), one of several events needs to be triggered:

  • continue: with the currently deployed model without any updates. However, an alert was raised so it should be analyzed later to reduce false positive alerts.
  • improve: by retraining the model to avoid performance degradation caused by meaningful drift (data, target, concept, etc.).
  • inspect: to make a decision. Typically expectations are reassessed, schemas are reevaluated for changes, slices are reevaluated, etc.
  • rollback: to a previous version of the model because of an issue with the current deployment. Typically these can be avoided using robust deployment strategies (ex. dark canary).

Retraining

If we need to improve on the existing version of the model, it's not just a matter of rerunning the model creation workflow on the new dataset. We need to carefully compose the training data in order to avoid issues such as catastrophic forgetting (forgetting previously learned patterns when presented with new data).

  • labeling: new incoming data may need to be properly labeled before being used (we cannot just depend on proxy labels).
  • active learning: we may not be able to explicitly label every single new data point so we have to leverage active learning workflows to complete the labeling process.
  • QA: quality assurance workflows to ensure that labeling is accurate, especially for known false positives/negatives and historically poorly performing slices of data.
  • augmentation: increasing our training set with augmented data that's representative of the original dataset.
  • sampling: upsampling and downsampling to address imbalanced data slices.
  • evaluation: creation of an evaluation dataset that's representative of what the model will encounter once deployed.

Once we have the proper dataset for retraining, we can kick off the workflows to update our system!


To cite this lesson, please use:

@article{madewithml,
    author       = {Goku Mohandas},
    title        = { Orchestration - Made With ML },
    howpublished = {\url{https://madewithml.com/}},
    year         = {2021}
}