Published on

MLflow Tracking Serverを​動かす: AppEngine FE + Cloud IAP

GKE (+ Ingress) もしくは App Engine Flexible Environment に加え Cloud IAP を利用すると、限定公開の MLflow Tracking Server を楽に構成できる。 この記事では、構成が容易な後者の AppEngine FE を利用した方法を紹介する。

目的と前提

  • Cloud IAP を利用して MLflow Tracking を楽にそこそこ安全に(not 安価に)動かす。
  • MLflow Tracking のバックエンド DB は CloudSQL、Artifact store は GCS とする。
  • MLflow 1.9.1 で確認。AppEngine FE のためのコンテナイメージの Python は 3.6.x。

今のところ自分の所属するプロダクトでは、Dataflow の GKE クラスタしか動いていないので、今回は GKE ではなく App Engine FE で済ませた。

※ここでの「セキュア」という表現について

パブリックなインターネットに晒すにあたり、HTTPS 対応と IAM による認証ができることが最低限の条件だとする。

ここでは扱わない内容と他の方法

  • GCP を使わない方法
  • GKE (Ingress) + Cloud IAP でやる方法
  • MLflow 自体に手を入れる方法
  • Cloud Run + Cloud Endpoints を使う方法
    • Cloud IAP は今のところ CloudRun には対応していないので、Endpoints を利用することになりそう。
    • Endpoints (Extensive Service Proxy beta 2 が CloudRun では使える。Envoy ベースの Proxy)
      • OpenAPI の定義があれば、よしなにしてくれるらしい
    • が、MLflow の REST API に swagger.yaml はない様子で、ProcolBuffers で定義されているだけだった ( mlflow/protos)

本編

サービスアカウントを作成する。以下のようなロールがあれば十分かも。

# Cloud IAPを利用する際には必須
IAP-secured Web App User

# BackendをCloud SQL, Artifact StorageをGCSにしている場合
Cloud SQL Client
Storage Object Creator
Storage Object Viewer

# アプリで使うRoleは適宜追加する
BigQuery User

環境変数の設定

# Service Account `[email protected]`
GOOGLE_APPLICATION_CREDENTIALS=service_account_key.json

クライアントを利用するときに、Cloud IAP の OAuth2 Client ID とそれに対応した Token を、Service Account の権限で取得する。

Cloud IAP settings
OAuth2 Client ID
import os, sys
from google.oauth2 import id_token
from google.auth.transport.requests import Request as AuthRequest
import mlflow

cid = "xxxxxxxxxxxxx.apps.googleusercontent.com"
os.environ["MLFLOW_TRACKING_TOKEN"] = id_token.fetch_id_token(AuthRequest(), cid)
mlflow.set_tracking_uri("https://mlflow-dot-the-project.appspot.com/")

自分の場合、Optuna や LightGBM などの Callback 内で MLflowClient を利用することが多いので、以下のような関数を MLflow Tracking API をコールする前に実行して、MLFLOW_TRACKING_TOKEN を更新するようにしている。

def authorize_mlflow(oauth2_client_id: str = None) -> None:
    """Set valid service-account path to 'GOOGLE_APPLICATION_CREDENTIALS' envvar """
    try:
        os.environ["MLFLOW_TRACKING_TOKEN"] = id_token.fetch_id_token(
            AuthRequest(), oauth2_client_id or os.environ.get("MLFLOW_OAUTH2_CLIENT_ID", "")
        )
    except GoogleAuthError as e:
        logger.debug(e)
        logger.warning("OAuth2 token authentication error")
    except Exception as e:
        logger.debug(e)
        logger.warning("Continue without authentication")

仕組み

MLflow 1.9 時点では、認証方式として BASIC 認証と Bearer Token による認証が利用できる。 公式ドキュメント にあるとおり、これらは MLFLOW_ Prefix の環境変数に与えることで利用できる。

上記の例で、MLflow Tracking Server の API をコールするたびに自前で Token を取得しているのは、MLflow 側に Token の更新処理が実装されていないため。

付録

app.yaml の例

  • liveness_check, readiness_check には、MLflow の /health エンドポイントが使える。
  • バックエンド DB として CloudSQL のインスタンスを指定できる。
runtime: custom
env: flex
service: mlflow
skip_files:
  - service_account.json
  - ^.*\.venv
  - ^.*\.env
  - ^.*\.terraform
  - ^.*tfvers.*
  - ^.*\.tf

entrypoint: ./entrypoint.sh

liveness_check:
  path: '/health'
  check_interval_sec: 30
  timeout_sec: 4
  failure_threshold: 2
  success_threshold: 2

readiness_check:
  path: '/health'
  check_interval_sec: 5
  timeout_sec: 4
  failure_threshold: 2
  success_threshold: 2
  app_start_timeout_sec: 60

beta_settings:
  cloud_sql_instances: { INSTANCE_CONNECTION_NAME }

resources:
  cpu: 2
  memory_gb: 4
  disk_size_gb: 10

manual_scaling:
  instances: 1

env_variables:
  DB_URI: mysql://{DB_USER}:{PASSWORD}/{DATABASE}?unix_socket=/cloudsql/{INSTANCE_CONNECTION_NAME}
  ARTIFACT_ROOT: { GCS_BUCKET }

Dockerfile

  • AppEngine FE では、app.yaml と同じディレクトリにある Dockerfile から、ランタイムイメージを Cloud Build でビルドして利用するので、これも必要。
  • 下記の段階では、python3 コマンドは Python 3.6 だった。
FROM gcr.io/google-appengine/python:2020-06-17-111334

RUN apt update && \
  apt install -y --no-install-recommends mysql-client libmysqlclient-dev python3-dev

ENV PYTHONFAULTHANDLER=1 \
  PYTHONUNBUFFERED=1 \
  PYTHONHASHSEED=random \
  # pip:
  PIP_NO_CACHE_DIR=on \
  PIP_DISABLE_PIP_VERSION_CHECK=on \
  PIP_DEFAULT_TIMEOUT=100

ARG MLFLOW_VERSION=1.9.1
RUN echo "Installing MLFlow ${MLFLOW_VERSION}"
RUN pip3 install mlflow[extras]==${MLFLOW_VERSION} mysqlclient

WORKDIR /mlflow
COPY ./entrypoint.sh /mlflow/
RUN chmod +x entrypoint.sh

EXPOSE 80 5000 8080
ENTRYPOINT [ "./entrypoint.sh" ]

entrypoint.sh

  • AppEngine で動かすので、8080 番ポートを使用する。
#!/bin/bash
HOST=${MLFLOW_TRACKING_HOST:-0.0.0.0}
PORT=${PORT:-8080}

sleep 5s
mlflow db upgrade "${DB_URI}"
mlflow server \
  --backend-store-uri "${DB_URI}" \
  --default-artifact-root "${ARTIFACT_ROOT}" \
  --host "${HOST}" --port "${PORT}"