こんにちは、データエンジニアリングチームの宮崎です。
最近はコロナが落ち着いてきたので、出社するようにしていますが、ずっとリモートワークだったので会社に着ていく服が無くて困っています。
さて、機械学習モデルをサービスで運用するためには、ただモデルをデプロイするだけでなく、アプリから利用するために前処理や後処理を実装する必要があります。 これらの処理が正しく実装されているか確認するためには、単体試験のテストコードを書く必要がありますが、その際、機械学習モデルの動きを擬似するためのモックを準備しなければなりません。
そこで今回は Amazon SageMakerにデプロイしたモデルに対し、Pythonから推論するコードをテストする際のモックについて検討したいと思います。
テスト対象のコード
今回は以下のコードをテストしたいと思います。
処理の内容としては、S3から画像など入力データをダウンロードし、Base64にエンコードした後、SageMakerの推論エンドポイントにリクエストを送ります。推論結果が得られたら、それを返すコードとなっています。
import json import base64 import boto3 def predict(s3_bucket, s3_key): session = boto3.Session() s3_client = session.client('s3') sagemaker_client = session.client('sagemaker-runtime') response = s3_client.get_object(Bucket=s3_bucket, Key=s3_key) data = response['Body'].read() encoded = base64.b64encode(data).decode('utf-8') response = sagemaker_client.invoke_endpoint( EndpointName='predictor', Body=json.dumps({'inputs': [encoded]}), ContentType='application/json' ) result = json.load(response['Body']) return result
モックの方針
このコードをテストするためにはS3とSageMaker Runtimeのモックを作成する必要があります。
Pythonでは AWS SDK のモックライブラリとして moto があります。 motoは非常に便利で、様々なAWSサービスのモックが用意されており、S3のモックも使用することができます。しかし、残念ながらSageMaker Runtimeはサポートされていません。 これはSageMakerの推論エンドポイントはデプロイされるモデルによって応答が変わるため、共通的なモックを実装することが難しいためと思われます。
そこで、SageMaker Runtimeについては自前でモックを準備したいと思います。 幸いにもSageMaker Python SDK内のテストコードにSageMaker Runtimeのテストコードが実装されており、こちらを参考にします。
以上より、motoのS3と自前のSageMaker Runtimeモックを組み合わせる、という方針で準備しようと思います。
テストコード
今回は pytest
を使用するテストコードを作成します。
test_predict
がpredict
のテスト関数となっており、引数の s3
は pytest.fixture で、motoの mock_s3 を使用してバケットを準備します。
さらに mocker を使用して、boto3.session.Session.client
にパッチを当て、SageMaker Runtimeのクライアント呼び出し時はモックするクライアントに差し替えます。
def test_predict(mocker, s3): mocker.patch('boto3.session.Session.client', side_effect=mock_client) s3_bucket = 'my_bucket' s3_key = 'my_data' result = predict(s3_bucket, s3_key) assert result == {'predictions': [0, 1, 2]}
S3のモック
motoの mock_s3
を使用してモック用のS3バケットを作成し、データを格納します。
なお、最新のmotoだとリージョンに us-east-1
を指定しないとエラーになってしまうため、注意してください。
また、pytest.fixture 内でmotoを使用する際は return
ではなく yield
にする必要があります。
import os import io import pytest import boto3 from moto import mock_s3 @pytest.fixture def s3(): os.environ['AWS_DEFAULT_REGION'] = 'us-east-1' s3_bucket = 'my_bucket' s3_key = 'my_data' with mock_s3(): session = boto3.Session() s3_client = session.client('s3') s3_client.create_bucket(Bucket=s3_bucket) data = b'abc' s3_client.put_object( Bucket=s3_bucket, Key=s3_key, Body=data) yield
SageMaker Runtimeのモック
続いてSageMaker Runtimeをモックします。 boto3.client を呼ばれた際に引数が sagemaker-runtime
であった場合はモック関数を返し、それ以外は通常の boto3.client を返します。これによって、引数がs3
の場合は裏側でmotoによりモックされます。
ポイントは、モックの無限ループにならないように、boto3.clientを mock_method_factory
の外で取得している点です。
from mock import Mock def mock_method_factory(client): def _mock_client(service_name, *args, **kwargs): if service_name == 'sagemaker-runtime': return sagemaker_runtime_mock() else: return client(boto3.Session(), service_name) def sagemaker_runtime_mock(): sagemaker_runtime_client = Mock(name='sagemaker_runtime') response_body = json.dumps({ 'predictions': [0, 1, 2] }).encode('utf-8') return_value = { 'ContentType': 'application/json', 'Body': io.BytesIO(response_body), } sagemaker_runtime_client.invoke_endpoint = Mock( name='invoke_endpoint', return_value=return_value ) return sagemaker_runtime_client return _mock_client mock_client = mock_method_factory(client=boto3.session.Session.client)
テストの実施
テスト対象のコードを predictor.py
、テストコードとモックコードを test_predictor.py
に書き、以下のように配置します。
. ├── predictor.py └── test_predictor.py
それではpytest
を使用して、テストコードを走らせてみましょう!
% pytest test_predictor.py -p no:warnings =========================== test session starts ============================ platform darwin -- Python 3.9.6, pytest-6.2.5, py-1.10.0, pluggy-1.0.0 rootdir: /path/to/dir plugins: anyio-3.3.0, mock-3.6.1, typeguard-2.12.1, cov-3.0.0 collected 1 item test_predictor.py . [100%] ============================ 1 passed in 0.63s =============================
無事にテストをパスすることができました!
まとめ
今回は SageMaker RuntimeとS3をモックしたテストコードを書いてみました。 モックを書くのは苦手でいつもハマってしまいます。 もし、より良い書き方がありましたら、コメントいただけると幸いです。
参考資料
本記事を書くにあたって、以下のコードを参考にいたしました。
- pytest と moto で優勝する - サーバーワークスエンジニアブログ
- sagemaker-python-sdk/test_predictor.py at master · aws/sagemaker-python-sdk · GitHub
- python - Max recursion depth while trying to mock instance method - Stack Overflow
ユニファで一緒に働く仲間を募集しています!