ユニファ開発者ブログ

ユニファ株式会社プロダクトデベロップメント本部メンバーによるブログです。

PythonでS3とSageMaker RuntimeをMockする

こんにちは、データエンジニアリングチームの宮崎です。

最近はコロナが落ち着いてきたので、出社するようにしていますが、ずっとリモートワークだったので会社に着ていく服が無くて困っています。

さて、機械学習モデルをサービスで運用するためには、ただモデルをデプロイするだけでなく、アプリから利用するために前処理や後処理を実装する必要があります。 これらの処理が正しく実装されているか確認するためには、単体試験のテストコードを書く必要がありますが、その際、機械学習モデルの動きを擬似するためのモックを準備しなければなりません。

そこで今回は Amazon SageMakerにデプロイしたモデルに対し、Pythonから推論するコードをテストする際のモックについて検討したいと思います。

テスト対象のコード

今回は以下のコードをテストしたいと思います。

処理の内容としては、S3から画像など入力データをダウンロードし、Base64にエンコードした後、SageMakerの推論エンドポイントにリクエストを送ります。推論結果が得られたら、それを返すコードとなっています。

import json
import base64
import boto3


def predict(s3_bucket, s3_key):
    session = boto3.Session()
    s3_client = session.client('s3')
    sagemaker_client = session.client('sagemaker-runtime')

    response = s3_client.get_object(Bucket=s3_bucket, Key=s3_key)
    data = response['Body'].read()
    encoded = base64.b64encode(data).decode('utf-8')
    response = sagemaker_client.invoke_endpoint(
        EndpointName='predictor',
        Body=json.dumps({'inputs': [encoded]}),
        ContentType='application/json'
    )
    result = json.load(response['Body'])
    return result

モックの方針

このコードをテストするためにはS3とSageMaker Runtimeのモックを作成する必要があります。

Pythonでは AWS SDK のモックライブラリとして moto があります。 motoは非常に便利で、様々なAWSサービスのモックが用意されており、S3のモックも使用することができます。しかし、残念ながらSageMaker Runtimeはサポートされていません。 これはSageMakerの推論エンドポイントはデプロイされるモデルによって応答が変わるため、共通的なモックを実装することが難しいためと思われます。

そこで、SageMaker Runtimeについては自前でモックを準備したいと思います。 幸いにもSageMaker Python SDK内のテストコードにSageMaker Runtimeのテストコードが実装されており、こちらを参考にします。

以上より、motoのS3と自前のSageMaker Runtimeモックを組み合わせる、という方針で準備しようと思います。

テストコード

今回は pytest を使用するテストコードを作成します。

test_predictpredictのテスト関数となっており、引数の s3 は pytest.fixture で、motoの mock_s3 を使用してバケットを準備します。 さらに mocker を使用して、boto3.session.Session.client にパッチを当て、SageMaker Runtimeのクライアント呼び出し時はモックするクライアントに差し替えます。

def test_predict(mocker, s3):
    mocker.patch('boto3.session.Session.client',
                 side_effect=mock_client)

    s3_bucket = 'my_bucket'
    s3_key = 'my_data'
    result = predict(s3_bucket, s3_key)
    assert result == {'predictions': [0, 1, 2]}

S3のモック

motoの mock_s3 を使用してモック用のS3バケットを作成し、データを格納します。 なお、最新のmotoだとリージョンに us-east-1 を指定しないとエラーになってしまうため、注意してください。

また、pytest.fixture 内でmotoを使用する際は return ではなく yield にする必要があります。

import os
import io
import pytest
import boto3
from moto import mock_s3

@pytest.fixture
def s3():
    os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'
    s3_bucket = 'my_bucket'
    s3_key = 'my_data'

    with mock_s3():
        session = boto3.Session()
        s3_client = session.client('s3')
        s3_client.create_bucket(Bucket=s3_bucket)
        data = b'abc'
        s3_client.put_object(
            Bucket=s3_bucket, Key=s3_key, Body=data)

        yield

SageMaker Runtimeのモック

続いてSageMaker Runtimeをモックします。 boto3.client を呼ばれた際に引数が sagemaker-runtime であった場合はモック関数を返し、それ以外は通常の boto3.client を返します。これによって、引数がs3の場合は裏側でmotoによりモックされます。

ポイントは、モックの無限ループにならないように、boto3.clientを mock_method_factoryの外で取得している点です。

from mock import Mock

def mock_method_factory(client):
    def _mock_client(service_name, *args, **kwargs):
        if service_name == 'sagemaker-runtime':
            return sagemaker_runtime_mock()
        else:
            return client(boto3.Session(), service_name)

    def sagemaker_runtime_mock():
        sagemaker_runtime_client = Mock(name='sagemaker_runtime')
        response_body = json.dumps({
            'predictions': [0, 1, 2]
        }).encode('utf-8')
        return_value = {
            'ContentType': 'application/json',
            'Body': io.BytesIO(response_body),
        }
        sagemaker_runtime_client.invoke_endpoint = Mock(
            name='invoke_endpoint', return_value=return_value
        )
        return sagemaker_runtime_client

    return _mock_client


mock_client = mock_method_factory(client=boto3.session.Session.client)

テストの実施

テスト対象のコードを predictor.py、テストコードとモックコードを test_predictor.pyに書き、以下のように配置します。

.
├── predictor.py
└── test_predictor.py

それではpytestを使用して、テストコードを走らせてみましょう!

% pytest test_predictor.py -p no:warnings
=========================== test session starts ============================
platform darwin -- Python 3.9.6, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /path/to/dir
plugins: anyio-3.3.0, mock-3.6.1, typeguard-2.12.1, cov-3.0.0
collected 1 item                                                           

test_predictor.py .                                                  [100%]

============================ 1 passed in 0.63s =============================

無事にテストをパスすることができました!

まとめ

今回は SageMaker RuntimeとS3をモックしたテストコードを書いてみました。 モックを書くのは苦手でいつもハマってしまいます。 もし、より良い書き方がありましたら、コメントいただけると幸いです。

参考資料

本記事を書くにあたって、以下のコードを参考にいたしました。


ユニファで一緒に働く仲間を募集しています!

unifa-e.com