Best Python code snippet using autotest_python
test_sqlalchemy_store.py
Source: test_sqlalchemy_store.py
1import os2import unittest3import mock4import tempfile5import uuid6import mlflow7import mlflow.db8import mlflow.store.db.base_sql_model9from mlflow.entities.model_registry import RegisteredModel, ModelVersion, \10 RegisteredModelTag, ModelVersionTag11from mlflow.exceptions import MlflowException12from mlflow.protos.databricks_pb2 import ErrorCode, RESOURCE_DOES_NOT_EXIST, \13 INVALID_PARAMETER_VALUE, RESOURCE_ALREADY_EXISTS14from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore15from tests.helper_functions import random_str16DB_URI = 'sqlite:///'17class TestSqlAlchemyStoreSqlite(unittest.TestCase):18 def _get_store(self, db_uri=''):19 return SqlAlchemyStore(db_uri)20 def setUp(self):21 self.maxDiff = None # print all differences on assert failures22 fd, self.temp_dbfile = tempfile.mkstemp()23 # Close handle immediately so that we can remove the file later on in Windows24 os.close(fd)25 self.db_url = "%s%s" % (DB_URI, self.temp_dbfile)26 self.store = self._get_store(self.db_url)27 def tearDown(self):28 mlflow.store.db.base_sql_model.Base.metadata.drop_all(self.store.engine)29 os.remove(self.temp_dbfile)30 def _rm_maker(self, name, tags=None):31 return self.store.create_registered_model(name, tags)32 def _mv_maker(self, name, source="path/to/source", run_id=uuid.uuid4().hex, tags=None):33 return self.store.create_model_version(name, source, run_id, tags)34 def _extract_latest_by_stage(self, latest_versions):35 return {mvd.current_stage: mvd.version for mvd in latest_versions}36 def test_create_registered_model(self):37 name = random_str() + "abCD"38 rm1 = self._rm_maker(name)39 self.assertEqual(rm1.name, name)40 # error on duplicate41 with self.assertRaises(MlflowException) as exception_context:42 self._rm_maker(name)43 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS)44 # slightly different name is ok45 for name2 in [name + "extra", name.lower(), name.upper(), name + name]:46 rm2 = self._rm_maker(name2)47 
self.assertEqual(rm2.name, name2)48 # test create model with tags49 name2 = random_str() + "tags"50 tags = [RegisteredModelTag("key", "value"),51 RegisteredModelTag("anotherKey", "some other value")]52 rm2 = self._rm_maker(name2, tags)53 rmd2 = self.store.get_registered_model(name2)54 self.assertEqual(rm2.name, name2)55 self.assertEqual(rm2.tags, {tag.key: tag.value for tag in tags})56 self.assertEqual(rmd2.name, name2)57 self.assertEqual(rmd2.tags, {tag.key: tag.value for tag in tags})58 # invalid model name will fail59 with self.assertRaises(MlflowException) as exception_context:60 self._rm_maker(None)61 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)62 with self.assertRaises(MlflowException) as exception_context:63 self._rm_maker("")64 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)65 def test_get_registered_model(self):66 name = "model_1"67 tags = [RegisteredModelTag("key", "value"),68 RegisteredModelTag("anotherKey", "some other value")]69 # use fake clock70 with mock.patch("time.time") as mock_time:71 mock_time.return_value = 123472 rm = self._rm_maker(name, tags)73 self.assertEqual(rm.name, name)74 rmd = self.store.get_registered_model(name=name)75 self.assertEqual(rmd.name, name)76 self.assertEqual(rmd.creation_timestamp, 1234000)77 self.assertEqual(rmd.last_updated_timestamp, 1234000)78 self.assertEqual(rmd.description, None)79 self.assertEqual(rmd.latest_versions, [])80 self.assertEqual(rmd.tags, {tag.key: tag.value for tag in tags})81 def test_update_registered_model(self):82 name = "model_for_update_RM"83 rm1 = self._rm_maker(name)84 rmd1 = self.store.get_registered_model(name=name)85 self.assertEqual(rm1.name, name)86 self.assertEqual(rmd1.description, None)87 # update description88 rm2 = self.store.update_registered_model(name=name, description="test model")89 rmd2 = self.store.get_registered_model(name=name)90 self.assertEqual(rm2.name, "model_for_update_RM")91 
self.assertEqual(rmd2.name, "model_for_update_RM")92 self.assertEqual(rmd2.description, "test model")93 def test_rename_registered_model(self):94 original_name = "original name"95 new_name = "new name"96 self._rm_maker(original_name)97 self._mv_maker(original_name)98 self._mv_maker(original_name)99 rm = self.store.get_registered_model(original_name)100 mv1 = self.store.get_model_version(original_name, 1)101 mv2 = self.store.get_model_version(original_name, 2)102 self.assertEqual(rm.name, original_name)103 self.assertEqual(mv1.name, original_name)104 self.assertEqual(mv2.name, original_name)105 # test renaming registered model also updates its model versions106 self.store.rename_registered_model(original_name, new_name)107 rm = self.store.get_registered_model(new_name)108 mv1 = self.store.get_model_version(new_name, 1)109 mv2 = self.store.get_model_version(new_name, 2)110 self.assertEqual(rm.name, new_name)111 self.assertEqual(mv1.name, new_name)112 self.assertEqual(mv2.name, new_name)113 # test accessing the model with the old name will fail114 with self.assertRaises(MlflowException) as exception_context:115 self.store.get_registered_model(original_name)116 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)117 # test name another model with the replaced name is ok118 self._rm_maker(original_name)119 # cannot rename model to conflict with an existing model120 with self.assertRaises(MlflowException) as exception_context:121 self.store.rename_registered_model(new_name, original_name)122 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS)123 # invalid model name will fail124 with self.assertRaises(MlflowException) as exception_context:125 self.store.rename_registered_model(original_name, None)126 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)127 with self.assertRaises(MlflowException) as exception_context:128 
self.store.rename_registered_model(original_name, "")129 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)130 def test_delete_registered_model(self):131 name = "model_for_delete_RM"132 self._rm_maker(name)133 self._mv_maker(name)134 rm1 = self.store.get_registered_model(name=name)135 mv1 = self.store.get_model_version(name, 1)136 self.assertEqual(rm1.name, name)137 self.assertEqual(mv1.name, name)138 # delete model139 self.store.delete_registered_model(name=name)140 # cannot get model141 with self.assertRaises(MlflowException) as exception_context:142 self.store.get_registered_model(name=name)143 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)144 # cannot update a delete model145 with self.assertRaises(MlflowException) as exception_context:146 self.store.update_registered_model(name=name, description="deleted")147 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)148 # cannot delete it again149 with self.assertRaises(MlflowException) as exception_context:150 self.store.delete_registered_model(name=name)151 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)152 # model versions are cascade deleted with the registered model153 with self.assertRaises(MlflowException) as exception_context:154 self.store.get_model_version(name, 1)155 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)156 def _list_registered_models(self, page_token=None, max_results=10):157 result = self.store.list_registered_models(max_results, page_token)158 for idx in range(len(result)):159 result[idx] = result[idx].name160 return result161 def test_list_registered_model(self):162 self._rm_maker("A")163 registered_models = self.store.list_registered_models(max_results=10, page_token=None)164 self.assertEqual(len(registered_models), 1)165 self.assertEqual(registered_models[0].name, "A")166 
self.assertIsInstance(registered_models[0], RegisteredModel)167 self._rm_maker("B")168 self.assertEqual(set(self._list_registered_models()),169 set(["A", "B"]))170 self._rm_maker("BB")171 self._rm_maker("BA")172 self._rm_maker("AB")173 self._rm_maker("BBC")174 self.assertEqual(set(self._list_registered_models()),175 set(["A", "B", "BB", "BA", "AB", "BBC"]))176 # list should not return deleted models177 self.store.delete_registered_model(name="BA")178 self.store.delete_registered_model(name="B")179 self.assertEqual(set(self._list_registered_models()),180 set(["A", "BB", "AB", "BBC"]))181 def test_list_registered_model_paginated_last_page(self):182 rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]183 # test flow with fixed max_results184 returned_rms = []185 result = self._list_registered_models(page_token=None, max_results=25)186 returned_rms.extend(result)187 while result.token:188 result = self._list_registered_models(page_token=result.token, max_results=25)189 self.assertEqual(len(result), 25)190 returned_rms.extend(result)191 self.assertEqual(result.token, None)192 self.assertEqual(set(rms), set(returned_rms))193 def test_list_registered_model_paginated_returns_in_correct_order(self):194 rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]195 # test that pagination will return all valid results in sorted order196 # by name ascending197 result = self._list_registered_models(max_results=5)198 self.assertNotEqual(result.token, None)199 self.assertEqual(result, rms[0:5])200 result = self._list_registered_models(page_token=result.token, max_results=10)201 self.assertNotEqual(result.token, None)202 self.assertEqual(result, rms[5:15])203 result = self._list_registered_models(page_token=result.token, max_results=20)204 self.assertNotEqual(result.token, None)205 self.assertEqual(result, rms[15:35])206 result = self._list_registered_models(page_token=result.token, max_results=100)207 # assert that page token is None208 self.assertEqual(result.token, 
None)209 self.assertEqual(result, rms[35:])210 def test_list_registered_model_paginated_errors(self):211 rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]212 # test that providing a completely invalid page token throws213 with self.assertRaises(MlflowException) as exception_context:214 self._list_registered_models(page_token="evilhax", max_results=20)215 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)216 # test that providing too large of a max_results throws217 with self.assertRaises(MlflowException) as exception_context:218 self._list_registered_models(page_token="evilhax", max_results=1e15)219 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)220 self.assertIn("Invalid value for request parameter max_results",221 exception_context.exception.message)222 # list should not return deleted models223 self.store.delete_registered_model(name=f"RM{0:03}")224 self.assertEqual(set(self._list_registered_models(max_results=100)),225 set(rms[1:]))226 def test_get_latest_versions(self):227 name = "test_for_latest_versions"228 self._rm_maker(name)229 rmd1 = self.store.get_registered_model(name=name)230 self.assertEqual(rmd1.latest_versions, [])231 mv1 = self._mv_maker(name)232 self.assertEqual(mv1.version, 1)233 rmd2 = self.store.get_registered_model(name=name)234 self.assertEqual(self._extract_latest_by_stage(rmd2.latest_versions), {"None": 1})235 # add a bunch more236 mv2 = self._mv_maker(name)237 self.assertEqual(mv2.version, 2)238 self.store.transition_model_version_stage(239 name=mv2.name, version=mv2.version, stage="Production",240 archive_existing_versions=False)241 mv3 = self._mv_maker(name)242 self.assertEqual(mv3.version, 3)243 self.store.transition_model_version_stage(name=mv3.name, version=mv3.version,244 stage="Production",245 archive_existing_versions=False)246 mv4 = self._mv_maker(name)247 self.assertEqual(mv4.version, 4)248 self.store.transition_model_version_stage(249 
name=mv4.name, version=mv4.version, stage="Staging",250 archive_existing_versions=False)251 # test that correct latest versions are returned for each stage252 rmd4 = self.store.get_registered_model(name=name)253 self.assertEqual(self._extract_latest_by_stage(rmd4.latest_versions),254 {"None": 1, "Production": 3, "Staging": 4})255 # delete latest Production, and should point to previous one256 self.store.delete_model_version(name=mv3.name, version=mv3.version)257 rmd5 = self.store.get_registered_model(name=name)258 self.assertEqual(self._extract_latest_by_stage(rmd5.latest_versions),259 {"None": 1, "Production": 2, "Staging": 4})260 def test_set_registered_model_tag(self):261 name1 = "SetRegisteredModelTag_TestMod"262 name2 = "SetRegisteredModelTag_TestMod 2"263 initial_tags = [RegisteredModelTag("key", "value"),264 RegisteredModelTag("anotherKey", "some other value")]265 self._rm_maker(name1, initial_tags)266 self._rm_maker(name2, initial_tags)267 new_tag = RegisteredModelTag("randomTag", "not a random value")268 self.store.set_registered_model_tag(name1, new_tag)269 rm1 = self.store.get_registered_model(name=name1)270 all_tags = initial_tags + [new_tag]271 self.assertEqual(rm1.tags, {tag.key: tag.value for tag in all_tags})272 # test overriding a tag with the same key273 overriding_tag = RegisteredModelTag("key", "overriding")274 self.store.set_registered_model_tag(name1, overriding_tag)275 all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]276 rm1 = self.store.get_registered_model(name=name1)277 self.assertEqual(rm1.tags, {tag.key: tag.value for tag in all_tags})278 # does not affect other models with the same key279 rm2 = self.store.get_registered_model(name=name2)280 self.assertEqual(rm2.tags, {tag.key: tag.value for tag in initial_tags})281 # can not set tag on deleted (non-existed) registered model282 self.store.delete_registered_model(name1)283 with self.assertRaises(MlflowException) as exception_context:284 
self.store.set_registered_model_tag(name1, overriding_tag)285 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)286 # test cannot set tags that are too long287 long_tag = RegisteredModelTag("longTagKey", "a" * 5001)288 with self.assertRaises(MlflowException) as exception_context:289 self.store.set_registered_model_tag(name2, long_tag)290 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)291 # test can set tags that are somewhat long292 long_tag = RegisteredModelTag("longTagKey", "a" * 4999)293 self.store.set_registered_model_tag(name2, long_tag)294 # can not set invalid tag295 with self.assertRaises(MlflowException) as exception_context:296 self.store.set_registered_model_tag(name2, RegisteredModelTag(key=None, value=""))297 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)298 # can not use invalid model name299 with self.assertRaises(MlflowException) as exception_context:300 self.store.set_registered_model_tag(None, RegisteredModelTag(key="key", value="value"))301 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)302 def test_delete_registered_model_tag(self):303 name1 = "DeleteRegisteredModelTag_TestMod"304 name2 = "DeleteRegisteredModelTag_TestMod 2"305 initial_tags = [RegisteredModelTag("key", "value"),306 RegisteredModelTag("anotherKey", "some other value")]307 self._rm_maker(name1, initial_tags)308 self._rm_maker(name2, initial_tags)309 new_tag = RegisteredModelTag("randomTag", "not a random value")310 self.store.set_registered_model_tag(name1, new_tag)311 self.store.delete_registered_model_tag(name1, "randomTag")312 rm1 = self.store.get_registered_model(name=name1)313 self.assertEqual(rm1.tags, {tag.key: tag.value for tag in initial_tags})314 # testing deleting a key does not affect other models with the same key315 self.store.delete_registered_model_tag(name1, "key")316 rm1 = 
self.store.get_registered_model(name=name1)317 rm2 = self.store.get_registered_model(name=name2)318 self.assertEqual(rm1.tags, {"anotherKey": "some other value"})319 self.assertEqual(rm2.tags, {tag.key: tag.value for tag in initial_tags})320 # delete tag that is already deleted does nothing321 self.store.delete_registered_model_tag(name1, "key")322 rm1 = self.store.get_registered_model(name=name1)323 self.assertEqual(rm1.tags, {"anotherKey": "some other value"})324 # can not delete tag on deleted (non-existed) registered model325 self.store.delete_registered_model(name1)326 with self.assertRaises(MlflowException) as exception_context:327 self.store.delete_registered_model_tag(name1, "anotherKey")328 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)329 # can not delete tag with invalid key330 with self.assertRaises(MlflowException) as exception_context:331 self.store.delete_registered_model_tag(name2, None)332 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)333 # can not use invalid model name334 with self.assertRaises(MlflowException) as exception_context:335 self.store.delete_registered_model_tag(None, "key")336 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)337 def test_create_model_version(self):338 name = "test_for_update_MV"339 self._rm_maker(name)340 run_id = uuid.uuid4().hex341 with mock.patch("time.time") as mock_time:342 mock_time.return_value = 456778343 mv1 = self._mv_maker(name, "a/b/CD", run_id)344 self.assertEqual(mv1.name, name)345 self.assertEqual(mv1.version, 1)346 mvd1 = self.store.get_model_version(mv1.name, mv1.version)347 self.assertEqual(mvd1.name, name)348 self.assertEqual(mvd1.version, 1)349 self.assertEqual(mvd1.current_stage, "None")350 self.assertEqual(mvd1.creation_timestamp, 456778000)351 self.assertEqual(mvd1.last_updated_timestamp, 456778000)352 self.assertEqual(mvd1.description, None)353 
self.assertEqual(mvd1.source, "a/b/CD")354 self.assertEqual(mvd1.run_id, run_id)355 self.assertEqual(mvd1.status, "READY")356 self.assertEqual(mvd1.status_message, None)357 self.assertEqual(mvd1.tags, {})358 # new model versions for same name autoincrement versions359 mv2 = self._mv_maker(name)360 mvd2 = self.store.get_model_version(name=mv2.name, version=mv2.version)361 self.assertEqual(mv2.version, 2)362 self.assertEqual(mvd2.version, 2)363 # create model version with tags return model version entity with tags364 tags = [ModelVersionTag("key", "value"),365 ModelVersionTag("anotherKey", "some other value")]366 mv3 = self._mv_maker(name, tags=tags)367 mvd3 = self.store.get_model_version(name=mv3.name, version=mv3.version)368 self.assertEqual(mv3.version, 3)369 self.assertEqual(mv3.tags, {tag.key: tag.value for tag in tags})370 self.assertEqual(mvd3.version, 3)371 self.assertEqual(mvd3.tags, {tag.key: tag.value for tag in tags})372 def test_update_model_version(self):373 name = "test_for_update_MV"374 self._rm_maker(name)375 mv1 = self._mv_maker(name)376 mvd1 = self.store.get_model_version(name=mv1.name, version=mv1.version)377 self.assertEqual(mvd1.name, name)378 self.assertEqual(mvd1.version, 1)379 self.assertEqual(mvd1.current_stage, "None")380 # update stage381 self.store.transition_model_version_stage(name=mv1.name, version=mv1.version,382 stage="Production",383 archive_existing_versions=False)384 mvd2 = self.store.get_model_version(name=mv1.name, version=mv1.version)385 self.assertEqual(mvd2.name, name)386 self.assertEqual(mvd2.version, 1)387 self.assertEqual(mvd2.current_stage, "Production")388 self.assertEqual(mvd2.description, None)389 # update description390 self.store.update_model_version(name=mv1.name, version=mv1.version,391 description="test model version")392 mvd3 = self.store.get_model_version(name=mv1.name, version=mv1.version)393 self.assertEqual(mvd3.name, name)394 self.assertEqual(mvd3.version, 1)395 self.assertEqual(mvd3.current_stage, 
"Production")396 self.assertEqual(mvd3.description, "test model version")397 # only valid stages can be set398 with self.assertRaises(MlflowException) as exception_context:399 self.store.transition_model_version_stage(mv1.name, mv1.version,400 stage="unknown",401 archive_existing_versions=False)402 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)403 # stages are case-insensitive and auto-corrected to system stage names404 for stage_name in ["STAGING", "staging", "StAgInG"]:405 self.store.transition_model_version_stage(406 name=mv1.name, version=mv1.version,407 stage=stage_name, archive_existing_versions=False)408 mvd5 = self.store.get_model_version(name=mv1.name, version=mv1.version)409 self.assertEqual(mvd5.current_stage, "Staging")410 def test_transition_model_version_stage_when_archive_existing_versions_is_false(self):411 name = "model"412 self._rm_maker(name)413 mv1 = self._mv_maker(name)414 mv2 = self._mv_maker(name)415 mv3 = self._mv_maker(name)416 # test that when `archive_existing_versions` is False, transitioning a model version417 # to the inactive stages ("Archived" and "None") does not throw.418 for stage in ["Archived", "None"]:419 self.store.transition_model_version_stage(name, mv1.version, stage, False)420 self.store.transition_model_version_stage(name, mv1.version, "Staging", False)421 self.store.transition_model_version_stage(name, mv2.version, "Production", False)422 self.store.transition_model_version_stage(name, mv3.version, "Staging", False)423 mvd1 = self.store.get_model_version(name=name, version=mv1.version)424 mvd2 = self.store.get_model_version(name=name, version=mv2.version)425 mvd3 = self.store.get_model_version(name=name, version=mv3.version)426 self.assertEqual(mvd1.current_stage, "Staging")427 self.assertEqual(mvd2.current_stage, "Production")428 self.assertEqual(mvd3.current_stage, "Staging")429 self.store.transition_model_version_stage(name, mv3.version, "Production", False)430 mvd1 = 
self.store.get_model_version(name=name, version=mv1.version)431 mvd2 = self.store.get_model_version(name=name, version=mv2.version)432 mvd3 = self.store.get_model_version(name=name, version=mv3.version)433 self.assertEqual(mvd1.current_stage, "Staging")434 self.assertEqual(mvd2.current_stage, "Production")435 self.assertEqual(mvd3.current_stage, "Production")436 def test_transition_model_version_stage_when_archive_existing_versions_is_true(self):437 name = "model"438 self._rm_maker(name)439 mv1 = self._mv_maker(name)440 mv2 = self._mv_maker(name)441 mv3 = self._mv_maker(name)442 msg = (r"Model version transition cannot archive existing model versions "443 r"because .+ is not an Active stage. Valid stages are .+")444 # test that when `archive_existing_versions` is True, transitioning a model version445 # to the inactive stages ("Archived" and "None") throws.446 for stage in ["Archived", "None"]:447 with self.assertRaisesRegex(MlflowException, msg):448 self.store.transition_model_version_stage(name, mv1.version, stage, True)449 self.store.transition_model_version_stage(name, mv1.version, "Staging", False)450 self.store.transition_model_version_stage(name, mv2.version, "Production", False)451 self.store.transition_model_version_stage(name, mv3.version, "Staging", True)452 mvd1 = self.store.get_model_version(name=name, version=mv1.version)453 mvd2 = self.store.get_model_version(name=name, version=mv2.version)454 mvd3 = self.store.get_model_version(name=name, version=mv3.version)455 self.assertEqual(mvd1.current_stage, "Archived")456 self.assertEqual(mvd2.current_stage, "Production")457 self.assertEqual(mvd3.current_stage, "Staging")458 self.assertEqual(mvd1.last_updated_timestamp, mvd3.last_updated_timestamp)459 self.store.transition_model_version_stage(name, mv3.version, "Production", True)460 mvd1 = self.store.get_model_version(name=name, version=mv1.version)461 mvd2 = self.store.get_model_version(name=name, version=mv2.version)462 mvd3 = 
self.store.get_model_version(name=name, version=mv3.version)463 self.assertEqual(mvd1.current_stage, "Archived")464 self.assertEqual(mvd2.current_stage, "Archived")465 self.assertEqual(mvd3.current_stage, "Production")466 self.assertEqual(mvd2.last_updated_timestamp, mvd3.last_updated_timestamp)467 def test_delete_model_version(self):468 name = "test_for_update_MV"469 initial_tags = [ModelVersionTag("key", "value"),470 ModelVersionTag("anotherKey", "some other value")]471 self._rm_maker(name)472 mv = self._mv_maker(name, tags=initial_tags)473 mvd = self.store.get_model_version(name=mv.name, version=mv.version)474 self.assertEqual(mvd.name, name)475 self.store.delete_model_version(name=mv.name, version=mv.version)476 # cannot get a deleted model version477 with self.assertRaises(MlflowException) as exception_context:478 self.store.get_model_version(name=mv.name, version=mv.version)479 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)480 # cannot update a delete481 with self.assertRaises(MlflowException) as exception_context:482 self.store.update_model_version(mv.name, mv.version, description="deleted!")483 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)484 # cannot delete it again485 with self.assertRaises(MlflowException) as exception_context:486 self.store.delete_model_version(name=mv.name, version=mv.version)487 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)488 def test_get_model_version_download_uri(self):489 name = "test_for_update_MV"490 self._rm_maker(name)491 source_path = "path/to/source"492 mv = self._mv_maker(name, source=source_path, run_id=uuid.uuid4().hex)493 mvd1 = self.store.get_model_version(name=mv.name, version=mv.version)494 self.assertEqual(mvd1.name, name)495 self.assertEqual(mvd1.source, source_path)496 # download location points to source497 self.assertEqual(self.store.get_model_version_download_uri(name=mv.name,498 
version=mv.version), source_path)499 # download URI does not change even if model version is updated500 self.store.transition_model_version_stage(501 name=mv.name, version=mv.version,502 stage="Production",503 archive_existing_versions=False)504 self.store.update_model_version(name=mv.name, version=mv.version,505 description="Test for Path")506 mvd2 = self.store.get_model_version(name=mv.name, version=mv.version)507 self.assertEqual(mvd2.source, source_path)508 self.assertEqual(self.store.get_model_version_download_uri(509 name=mv.name, version=mv.version), source_path)510 # cannot retrieve download URI for deleted model versions511 self.store.delete_model_version(name=mv.name, version=mv.version)512 with self.assertRaises(MlflowException) as exception_context:513 self.store.get_model_version_download_uri(name=mv.name, version=mv.version)514 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)515 def test_search_model_versions(self):516 # create some model versions517 name = "test_for_search_MV"518 self._rm_maker(name)519 run_id_1 = uuid.uuid4().hex520 run_id_2 = uuid.uuid4().hex521 run_id_3 = uuid.uuid4().hex522 mv1 = self._mv_maker(name=name, source="A/B", run_id=run_id_1)523 self.assertEqual(mv1.version, 1)524 mv2 = self._mv_maker(name=name, source="A/C", run_id=run_id_2)525 self.assertEqual(mv2.version, 2)526 mv3 = self._mv_maker(name=name, source="A/D", run_id=run_id_2)527 self.assertEqual(mv3.version, 3)528 mv4 = self._mv_maker(name=name, source="A/D", run_id=run_id_3)529 self.assertEqual(mv4.version, 4)530 def search_versions(filter_string):531 return [mvd.version for mvd in self.store.search_model_versions(filter_string)]532 # search using name should return all 4 versions533 self.assertEqual(set(search_versions("name='%s'" % name)), set([1, 2, 3, 4]))534 # search using run_id_1 should return version 1535 self.assertEqual(set(search_versions("run_id='%s'" % run_id_1)), set([1]))536 # search using run_id_2 should return 
versions 2 and 3537 self.assertEqual(set(search_versions("run_id='%s'" % run_id_2)), set([2, 3]))538 # search using source_path "A/D" should return version 3 and 4539 self.assertEqual(set(search_versions("source_path = 'A/D'")), set([3, 4]))540 # search using source_path "A" should not return anything541 self.assertEqual(len(search_versions("source_path = 'A'")), 0)542 self.assertEqual(len(search_versions("source_path = 'A/'")), 0)543 self.assertEqual(len(search_versions("source_path = ''")), 0)544 # delete mv4. search should not return version 4545 self.store.delete_model_version(name=mv4.name, version=mv4.version)546 self.assertEqual(set(search_versions("")), set([1, 2, 3]))547 self.assertEqual(set(search_versions(None)), set([1, 2, 3]))548 self.assertEqual(set(search_versions("name='%s'" % name)), set([1, 2, 3]))549 self.assertEqual(set(search_versions("source_path = 'A/D'")), set([3]))550 self.store.transition_model_version_stage(551 name=mv1.name, version=mv1.version, stage="production",552 archive_existing_versions=False553 )554 self.store.update_model_version(555 name=mv1.name, version=mv1.version, description="Online prediction model!")556 mvds = self.store.search_model_versions("run_id = '%s'" % run_id_1)557 assert 1 == len(mvds)558 assert isinstance(mvds[0], ModelVersion)559 assert mvds[0].current_stage == "Production"560 assert mvds[0].run_id == run_id_1561 assert mvds[0].source == "A/B"562 assert mvds[0].description == "Online prediction model!"563 def _search_registered_models(self,564 filter_string,565 max_results=10,566 order_by=None,567 page_token=None):568 result = self.store.search_registered_models(filter_string=filter_string,569 max_results=max_results,570 order_by=order_by,571 page_token=page_token)572 return [registered_model.name for registered_model in result], result.token573 def test_search_registered_models(self):574 # create some registered models575 prefix = "test_for_search_"576 names = [prefix + name for name in ["RM1", "RM2", "RM3", 
"RM4", "RM4A", "RM4a"]]577 [self._rm_maker(name) for name in names]578 # search with no filter should return all registered models579 rms, _ = self._search_registered_models(None)580 self.assertEqual(rms, names)581 # equality search using name should return exactly the 1 name582 rms, _ = self._search_registered_models(f"name='{names[0]}'")583 self.assertEqual(rms, [names[0]])584 # equality search using name that is not valid should return nothing585 rms, _ = self._search_registered_models(f"name='{names[0] + 'cats'}'")586 self.assertEqual(rms, [])587 # case-sensitive prefix search using LIKE should return all the RMs588 rms, _ = self._search_registered_models(f"name LIKE '{prefix}%'")589 self.assertEqual(rms, names)590 # case-sensitive prefix search using LIKE with surrounding % should return all the RMs591 rms, _ = self._search_registered_models(f"name LIKE '%RM%'")592 self.assertEqual(rms, names)593 # case-sensitive prefix search using LIKE with surrounding % should return all the RMs594 # _e% matches test_for_search_ , so all RMs should match595 rms, _ = self._search_registered_models(f"name LIKE '_e%'")596 self.assertEqual(rms, names)597 # case-sensitive prefix search using LIKE should return just rm4598 rms, _ = self._search_registered_models(f"name LIKE '{prefix + 'RM4A'}%'")599 self.assertEqual(rms, [names[4]])600 # case-sensitive prefix search using LIKE should return no models if no match601 rms, _ = self._search_registered_models(f"name LIKE '{prefix + 'cats'}%'")602 self.assertEqual(rms, [])603 # confirm that LIKE is not case-sensitive604 rms, _ = self._search_registered_models(f"name lIkE '%blah%'")605 self.assertEqual(rms, [])606 rms, _ = self._search_registered_models(f"name like '{prefix + 'RM4A'}%'")607 self.assertEqual(rms, [names[4]])608 # case-insensitive prefix search using ILIKE should return both rm5 and rm6609 rms, _ = self._search_registered_models(f"name ILIKE '{prefix + 'RM4A'}%'")610 self.assertEqual(rms, names[4:])611 # case-insensitive 
postfix search with ILIKE612 rms, _ = self._search_registered_models(f"name ILIKE '%RM4a'")613 self.assertEqual(rms, names[4:])614 # case-insensitive prefix search using ILIKE should return both rm5 and rm6615 rms, _ = self._search_registered_models(f"name ILIKE '{prefix + 'cats'}%'")616 self.assertEqual(rms, [])617 # confirm that ILIKE is not case-sensitive618 rms, _ = self._search_registered_models(f"name iLike '%blah%'")619 self.assertEqual(rms, [])620 # confirm that ILIKE works for empty query621 rms, _ = self._search_registered_models(f"name iLike '%%'")622 self.assertEqual(rms, names)623 rms, _ = self._search_registered_models(f"name ilike '%RM4a'")624 self.assertEqual(rms, names[4:])625 # cannot search by invalid comparator types626 with self.assertRaises(MlflowException) as exception_context:627 self._search_registered_models("name!=something")628 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)629 # cannot search by run_id630 with self.assertRaises(MlflowException) as exception_context:631 self._search_registered_models("run_id='%s'" % "somerunID")632 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)633 # cannot search by source_path634 with self.assertRaises(MlflowException) as exception_context:635 self._search_registered_models("source_path = 'A/D'")636 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)637 # cannot search by other params638 with self.assertRaises(MlflowException) as exception_context:639 self._search_registered_models("evilhax = true")640 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)641 # delete last registered model. 
search should not return the first 5642 self.store.delete_registered_model(name=names[-1])643 self.assertEqual(self._search_registered_models(None, max_results=1000), (names[:-1], None))644 # equality search using name should return no names645 self.assertEqual(self._search_registered_models(f"name='{names[-1]}'"), ([], None))646 # case-sensitive prefix search using LIKE should return all the RMs647 self.assertEqual(self._search_registered_models(f"name LIKE '{prefix}%'"),648 (names[0:5], None))649 # case-insensitive prefix search using ILIKE should return both rm5 and rm6650 self.assertEqual(self._search_registered_models(f"name ILIKE '{prefix + 'RM4A'}%'"),651 ([names[4]], None))652 def test_search_registered_model_pagination(self):653 rms = [self._rm_maker(f"RM{i:03}").name for i in range(50)]654 # test flow with fixed max_results655 returned_rms = []656 query = "name LIKE 'RM%'"657 result, token = self._search_registered_models(query, page_token=None, max_results=5)658 returned_rms.extend(result)659 while token:660 result, token = self._search_registered_models(query, page_token=token, max_results=5)661 returned_rms.extend(result)662 self.assertEqual(rms, returned_rms)663 # test that pagination will return all valid results in sorted order664 # by name ascending665 result, token1 = self._search_registered_models(query, max_results=5)666 self.assertNotEqual(token1, None)667 self.assertEqual(result, rms[0:5])668 result, token2 = self._search_registered_models(query, page_token=token1, max_results=10)669 self.assertNotEqual(token2, None)670 self.assertEqual(result, rms[5:15])671 result, token3 = self._search_registered_models(query, page_token=token2, max_results=20)672 self.assertNotEqual(token3, None)673 self.assertEqual(result, rms[15:35])674 result, token4 = self._search_registered_models(query, page_token=token3, max_results=100)675 # assert that page token is None676 self.assertEqual(token4, None)677 self.assertEqual(result, rms[35:])678 # test that 
providing a completely invalid page token throws679 with self.assertRaises(MlflowException) as exception_context:680 self._search_registered_models(query, page_token="evilhax", max_results=20)681 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)682 # test that providing too large of a max_results throws683 with self.assertRaises(MlflowException) as exception_context:684 self._search_registered_models(query, page_token="evilhax", max_results=1e15)685 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)686 self.assertIn("Invalid value for request parameter max_results",687 exception_context.exception.message)688 def test_search_registered_model_order_by(self):689 rms = []690 # explicitly mock the creation_timestamps because timestamps seem to be unstable in Windows691 for i in range(50):692 with mock.patch("mlflow.store.model_registry.sqlalchemy_store.now", return_value=i):693 rms.append(self._rm_maker(f"RM{i:03}").name)694 # test flow with fixed max_results and order_by (test stable order across pages)695 returned_rms = []696 query = "name LIKE 'RM%'"697 result, token = self._search_registered_models(query,698 page_token=None,699 order_by=['name DESC'],700 max_results=5)701 returned_rms.extend(result)702 while token:703 result, token = self._search_registered_models(query,704 page_token=token,705 order_by=['name DESC'],706 max_results=5)707 returned_rms.extend(result)708 # name descending should be the opposite order of the current order709 self.assertEqual(rms[::-1], returned_rms)710 # last_updated_timestamp descending should have the newest RMs first711 result, _ = self._search_registered_models(query,712 page_token=None,713 order_by=['last_updated_timestamp DESC'],714 max_results=100)715 self.assertEqual(rms[::-1], result)716 # timestamp returns same result as last_updated_timestamp717 result, _ = self._search_registered_models(query,718 page_token=None,719 order_by=['timestamp DESC'],720 
max_results=100)721 self.assertEqual(rms[::-1], result)722 # last_updated_timestamp ascending should have the oldest RMs first723 result, _ = self._search_registered_models(query,724 page_token=None,725 order_by=['last_updated_timestamp ASC'],726 max_results=100)727 self.assertEqual(rms, result)728 # timestamp returns same result as last_updated_timestamp729 result, _ = self._search_registered_models(query,730 page_token=None,731 order_by=['timestamp ASC'],732 max_results=100)733 self.assertEqual(rms, result)734 # timestamp returns same result as last_updated_timestamp735 result, _ = self._search_registered_models(query,736 page_token=None,737 order_by=['timestamp'],738 max_results=100)739 self.assertEqual(rms, result)740 # name ascending should have the original order741 result, _ = self._search_registered_models(query,742 page_token=None,743 order_by=['name ASC'],744 max_results=100)745 self.assertEqual(rms, result)746 # test that no ASC/DESC defaults to ASC747 result, _ = self._search_registered_models(query,748 page_token=None,749 order_by=['last_updated_timestamp'],750 max_results=100)751 self.assertEqual(rms, result)752 with mock.patch("mlflow.store.model_registry.sqlalchemy_store.now", return_value=1):753 rm1 = self._rm_maker("MR1").name754 rm2 = self._rm_maker("MR2").name755 with mock.patch("mlflow.store.model_registry.sqlalchemy_store.now", return_value=2):756 rm3 = self._rm_maker("MR3").name757 rm4 = self._rm_maker("MR4").name758 query = "name LIKE 'MR%'"759 # test with multiple clauses760 result, _ = self._search_registered_models(query,761 page_token=None,762 order_by=['last_updated_timestamp ASC',763 'name DESC'],764 max_results=100)765 self.assertEqual([rm2, rm1, rm4, rm3], result)766 result, _ = self._search_registered_models(query,767 page_token=None,768 order_by=['timestamp ASC',769 'name DESC'],770 max_results=100)771 self.assertEqual([rm2, rm1, rm4, rm3], result)772 # confirm that name ascending is the default, even if ties exist on other 
fields773 result, _ = self._search_registered_models(query,774 page_token=None,775 order_by=[],776 max_results=100)777 self.assertEqual([rm1, rm2, rm3, rm4], result)778 # test default tiebreak with descending timestamps779 result, _ = self._search_registered_models(query,780 page_token=None,781 order_by=['last_updated_timestamp DESC'],782 max_results=100)783 self.assertEqual([rm3, rm4, rm1, rm2], result)784 # test timestamp parsing785 result, _ = self._search_registered_models(query,786 page_token=None,787 order_by=['timestamp\tASC'],788 max_results=100)789 self.assertEqual([rm1, rm2, rm3, rm4], result)790 result, _ = self._search_registered_models(query,791 page_token=None,792 order_by=['timestamp\r\rASC'],793 max_results=100)794 self.assertEqual([rm1, rm2, rm3, rm4], result)795 result, _ = self._search_registered_models(query,796 page_token=None,797 order_by=['timestamp\nASC'],798 max_results=100)799 self.assertEqual([rm1, rm2, rm3, rm4], result)800 result, _ = self._search_registered_models(query,801 page_token=None,802 order_by=['timestamp ASC'],803 max_results=100)804 self.assertEqual([rm1, rm2, rm3, rm4], result)805 # validate order by key is case-insensitive806 result, _ = self._search_registered_models(query,807 page_token=None,808 order_by=['timestamp asc'],809 max_results=100)810 self.assertEqual([rm1, rm2, rm3, rm4], result)811 result, _ = self._search_registered_models(query,812 page_token=None,813 order_by=['timestamp aSC'],814 max_results=100)815 self.assertEqual([rm1, rm2, rm3, rm4], result)816 result, _ = self._search_registered_models(query,817 page_token=None,818 order_by=['timestamp desc',819 'name desc'],820 max_results=100)821 self.assertEqual([rm4, rm3, rm2, rm1], result)822 result, _ = self._search_registered_models(query,823 page_token=None,824 order_by=['timestamp deSc',825 'name deSc'],826 max_results=100)827 self.assertEqual([rm4, rm3, rm2, rm1], result)828 def test_search_registered_model_order_by_errors(self):829 query = "name LIKE 
'RM%'"830 # test that invalid columns throw even if they come after valid columns831 with self.assertRaises(MlflowException) as exception_context:832 self._search_registered_models(query,833 page_token=None,834 order_by=['name ASC', 'creation_timestamp DESC'],835 max_results=5)836 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)837 # test that invalid columns with random text throw even if they come after valid columns838 with self.assertRaises(MlflowException) as exception_context:839 self._search_registered_models(query,840 page_token=None,841 order_by=['name ASC',842 'last_updated_timestamp DESC blah'],843 max_results=5)844 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)845 def test_set_model_version_tag(self):846 name1 = "SetModelVersionTag_TestMod"847 name2 = "SetModelVersionTag_TestMod 2"848 initial_tags = [ModelVersionTag("key", "value"),849 ModelVersionTag("anotherKey", "some other value")]850 self._rm_maker(name1)851 self._rm_maker(name2)852 run_id_1 = uuid.uuid4().hex853 run_id_2 = uuid.uuid4().hex854 run_id_3 = uuid.uuid4().hex855 self._mv_maker(name1, "A/B", run_id_1, initial_tags)856 self._mv_maker(name1, "A/C", run_id_2, initial_tags)857 self._mv_maker(name2, "A/D", run_id_3, initial_tags)858 new_tag = ModelVersionTag("randomTag", "not a random value")859 self.store.set_model_version_tag(name1, 1, new_tag)860 all_tags = initial_tags + [new_tag]861 rm1mv1 = self.store.get_model_version(name1, 1)862 self.assertEqual(rm1mv1.tags, {tag.key: tag.value for tag in all_tags})863 # test overriding a tag with the same key864 overriding_tag = ModelVersionTag("key", "overriding")865 self.store.set_model_version_tag(name1, 1, overriding_tag)866 all_tags = [tag for tag in all_tags if tag.key != "key"] + [overriding_tag]867 rm1mv1 = self.store.get_model_version(name1, 1)868 self.assertEqual(rm1mv1.tags, {tag.key: tag.value for tag in all_tags})869 # does not affect other model versions 
with the same key870 rm1mv2 = self.store.get_model_version(name1, 2)871 rm2mv1 = self.store.get_model_version(name2, 1)872 self.assertEqual(rm1mv2.tags, {tag.key: tag.value for tag in initial_tags})873 self.assertEqual(rm2mv1.tags, {tag.key: tag.value for tag in initial_tags})874 # can not set tag on deleted (non-existed) model version875 self.store.delete_model_version(name1, 2)876 with self.assertRaises(MlflowException) as exception_context:877 self.store.set_model_version_tag(name1, 2, overriding_tag)878 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)879 # test cannot set tags that are too long880 long_tag = ModelVersionTag("longTagKey", "a" * 5001)881 with self.assertRaises(MlflowException) as exception_context:882 self.store.set_model_version_tag(name1, 1, long_tag)883 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)884 # test can set tags that are somewhat long885 long_tag = ModelVersionTag("longTagKey", "a" * 4999)886 self.store.set_model_version_tag(name1, 1, long_tag)887 # can not set invalid tag888 with self.assertRaises(MlflowException) as exception_context:889 self.store.set_model_version_tag(name2, 1, ModelVersionTag(key=None, value=""))890 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)891 # can not use invalid model name or version892 with self.assertRaises(MlflowException) as exception_context:893 self.store.set_model_version_tag(None, 1, ModelVersionTag(key="key", value="value"))894 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)895 with self.assertRaises(MlflowException) as exception_context:896 self.store.set_model_version_tag(name2, "I am not a version",897 ModelVersionTag(key="key", value="value"))898 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)899 def test_delete_model_version_tag(self):900 name1 = "DeleteModelVersionTag_TestMod"901 name2 = 
"DeleteModelVersionTag_TestMod 2"902 initial_tags = [ModelVersionTag("key", "value"),903 ModelVersionTag("anotherKey", "some other value")]904 self._rm_maker(name1)905 self._rm_maker(name2)906 run_id_1 = uuid.uuid4().hex907 run_id_2 = uuid.uuid4().hex908 run_id_3 = uuid.uuid4().hex909 self._mv_maker(name1, "A/B", run_id_1, initial_tags)910 self._mv_maker(name1, "A/C", run_id_2, initial_tags)911 self._mv_maker(name2, "A/D", run_id_3, initial_tags)912 new_tag = ModelVersionTag("randomTag", "not a random value")913 self.store.set_model_version_tag(name1, 1, new_tag)914 self.store.delete_model_version_tag(name1, 1, "randomTag")915 rm1mv1 = self.store.get_model_version(name1, 1)916 self.assertEqual(rm1mv1.tags, {tag.key: tag.value for tag in initial_tags})917 # testing deleting a key does not affect other model versions with the same key918 self.store.delete_model_version_tag(name1, 1, "key")919 rm1mv1 = self.store.get_model_version(name1, 1)920 rm1mv2 = self.store.get_model_version(name1, 2)921 rm2mv1 = self.store.get_model_version(name2, 1)922 self.assertEqual(rm1mv1.tags, {"anotherKey": "some other value"})923 self.assertEqual(rm1mv2.tags, {tag.key: tag.value for tag in initial_tags})924 self.assertEqual(rm2mv1.tags, {tag.key: tag.value for tag in initial_tags})925 # delete tag that is already deleted does nothing926 self.store.delete_model_version_tag(name1, 1, "key")927 rm1mv1 = self.store.get_model_version(name1, 1)928 self.assertEqual(rm1mv1.tags, {"anotherKey": "some other value"})929 # can not delete tag on deleted (non-existed) model version930 self.store.delete_model_version(name2, 1)931 with self.assertRaises(MlflowException) as exception_context:932 self.store.delete_model_version_tag(name2, 1, "key")933 assert exception_context.exception.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)934 # can not delete tag with invalid key935 with self.assertRaises(MlflowException) as exception_context:936 self.store.delete_model_version_tag(name1, 2, None)937 
assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)938 # can not use invalid model name or version939 with self.assertRaises(MlflowException) as exception_context:940 self.store.delete_model_version_tag(None, 2, "key")941 assert exception_context.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)942 with self.assertRaises(MlflowException) as exception_context:943 self.store.delete_model_version_tag(name1, "I am not a version", "key")...
_metadata_code_details_test.py
Source:_metadata_code_details_test.py
1# Copyright 2016 gRPC authors.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14"""Tests application-provided metadata, status code, and details."""15import threading16import unittest17import grpc18from tests.unit import test_common19from tests.unit.framework.common import test_constants20from tests.unit.framework.common import test_control21_SERIALIZED_REQUEST = b'\x46\x47\x48'22_SERIALIZED_RESPONSE = b'\x49\x50\x51'23_REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST24_REQUEST_DESERIALIZER = lambda unused_serialized_request: object()25_RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE26_RESPONSE_DESERIALIZER = lambda unused_serialized_response: object()27_SERVICE = 'test.TestService'28_UNARY_UNARY = 'UnaryUnary'29_UNARY_STREAM = 'UnaryStream'30_STREAM_UNARY = 'StreamUnary'31_STREAM_STREAM = 'StreamStream'32_CLIENT_METADATA = (('client-md-key', 'client-md-key'), ('client-md-key-bin',33 b'\x00\x01'))34_SERVER_INITIAL_METADATA = (('server-initial-md-key',35 'server-initial-md-value'),36 ('server-initial-md-key-bin', b'\x00\x02'))37_SERVER_TRAILING_METADATA = (('server-trailing-md-key',38 'server-trailing-md-value'),39 ('server-trailing-md-key-bin', b'\x00\x03'))40_NON_OK_CODE = grpc.StatusCode.NOT_FOUND41_DETAILS = 'Test details!'42# calling abort should always fail an RPC, even for "invalid" codes43_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK)44_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN,45 
grpc.StatusCode.UNKNOWN)46_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '')47class _Servicer(object):48 def __init__(self):49 self._lock = threading.Lock()50 self._abort_call = False51 self._code = None52 self._details = None53 self._exception = False54 self._return_none = False55 self._received_client_metadata = None56 def unary_unary(self, request, context):57 with self._lock:58 self._received_client_metadata = context.invocation_metadata()59 context.send_initial_metadata(_SERVER_INITIAL_METADATA)60 context.set_trailing_metadata(_SERVER_TRAILING_METADATA)61 if self._abort_call:62 context.abort(self._code, self._details)63 else:64 if self._code is not None:65 context.set_code(self._code)66 if self._details is not None:67 context.set_details(self._details)68 if self._exception:69 raise test_control.Defect()70 else:71 return None if self._return_none else object()72 def unary_stream(self, request, context):73 with self._lock:74 self._received_client_metadata = context.invocation_metadata()75 context.send_initial_metadata(_SERVER_INITIAL_METADATA)76 context.set_trailing_metadata(_SERVER_TRAILING_METADATA)77 if self._abort_call:78 context.abort(self._code, self._details)79 else:80 if self._code is not None:81 context.set_code(self._code)82 if self._details is not None:83 context.set_details(self._details)84 for _ in range(test_constants.STREAM_LENGTH // 2):85 yield _SERIALIZED_RESPONSE86 if self._exception:87 raise test_control.Defect()88 def stream_unary(self, request_iterator, context):89 with self._lock:90 self._received_client_metadata = context.invocation_metadata()91 context.send_initial_metadata(_SERVER_INITIAL_METADATA)92 context.set_trailing_metadata(_SERVER_TRAILING_METADATA)93 # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the94 # request iterator.95 list(request_iterator)96 if self._abort_call:97 context.abort(self._code, self._details)98 else:99 if self._code is not None:100 context.set_code(self._code)101 if self._details is not None:102 
context.set_details(self._details)103 if self._exception:104 raise test_control.Defect()105 else:106 return None if self._return_none else _SERIALIZED_RESPONSE107 def stream_stream(self, request_iterator, context):108 with self._lock:109 self._received_client_metadata = context.invocation_metadata()110 context.send_initial_metadata(_SERVER_INITIAL_METADATA)111 context.set_trailing_metadata(_SERVER_TRAILING_METADATA)112 # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the113 # request iterator.114 list(request_iterator)115 if self._abort_call:116 context.abort(self._code, self._details)117 else:118 if self._code is not None:119 context.set_code(self._code)120 if self._details is not None:121 context.set_details(self._details)122 for _ in range(test_constants.STREAM_LENGTH // 3):123 yield object()124 if self._exception:125 raise test_control.Defect()126 def set_abort_call(self):127 with self._lock:128 self._abort_call = True129 def set_code(self, code):130 with self._lock:131 self._code = code132 def set_details(self, details):133 with self._lock:134 self._details = details135 def set_exception(self):136 with self._lock:137 self._exception = True138 def set_return_none(self):139 with self._lock:140 self._return_none = True141 def received_client_metadata(self):142 with self._lock:143 return self._received_client_metadata144def _generic_handler(servicer):145 method_handlers = {146 _UNARY_UNARY:147 grpc.unary_unary_rpc_method_handler(148 servicer.unary_unary,149 request_deserializer=_REQUEST_DESERIALIZER,150 response_serializer=_RESPONSE_SERIALIZER),151 _UNARY_STREAM:152 grpc.unary_stream_rpc_method_handler(servicer.unary_stream),153 _STREAM_UNARY:154 grpc.stream_unary_rpc_method_handler(servicer.stream_unary),155 _STREAM_STREAM:156 grpc.stream_stream_rpc_method_handler(157 servicer.stream_stream,158 request_deserializer=_REQUEST_DESERIALIZER,159 response_serializer=_RESPONSE_SERIALIZER),160 }161 return grpc.method_handlers_generic_handler(_SERVICE, 
method_handlers)162class MetadataCodeDetailsTest(unittest.TestCase):163 def setUp(self):164 self._servicer = _Servicer()165 self._server = test_common.test_server()166 self._server.add_generic_rpc_handlers(167 (_generic_handler(self._servicer),))168 port = self._server.add_insecure_port('[::]:0')169 self._server.start()170 channel = grpc.insecure_channel('localhost:{}'.format(port))171 self._unary_unary = channel.unary_unary(172 '/'.join((173 '',174 _SERVICE,175 _UNARY_UNARY,176 )),177 request_serializer=_REQUEST_SERIALIZER,178 response_deserializer=_RESPONSE_DESERIALIZER,179 )180 self._unary_stream = channel.unary_stream('/'.join((181 '',182 _SERVICE,183 _UNARY_STREAM,184 )),)185 self._stream_unary = channel.stream_unary('/'.join((186 '',187 _SERVICE,188 _STREAM_UNARY,189 )),)190 self._stream_stream = channel.stream_stream(191 '/'.join((192 '',193 _SERVICE,194 _STREAM_STREAM,195 )),196 request_serializer=_REQUEST_SERIALIZER,197 response_deserializer=_RESPONSE_DESERIALIZER,198 )199 def testSuccessfulUnaryUnary(self):200 self._servicer.set_details(_DETAILS)201 unused_response, call = self._unary_unary.with_call(202 object(), metadata=_CLIENT_METADATA)203 self.assertTrue(204 test_common.metadata_transmitted(205 _CLIENT_METADATA, self._servicer.received_client_metadata()))206 self.assertTrue(207 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,208 call.initial_metadata()))209 self.assertTrue(210 test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,211 call.trailing_metadata()))212 self.assertIs(grpc.StatusCode.OK, call.code())213 self.assertEqual(_DETAILS, call.details())214 def testSuccessfulUnaryStream(self):215 self._servicer.set_details(_DETAILS)216 response_iterator_call = self._unary_stream(217 _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)218 received_initial_metadata = response_iterator_call.initial_metadata()219 list(response_iterator_call)220 self.assertTrue(221 test_common.metadata_transmitted(222 _CLIENT_METADATA, 
self._servicer.received_client_metadata()))223 self.assertTrue(224 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,225 received_initial_metadata))226 self.assertTrue(227 test_common.metadata_transmitted(228 _SERVER_TRAILING_METADATA,229 response_iterator_call.trailing_metadata()))230 self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())231 self.assertEqual(_DETAILS, response_iterator_call.details())232 def testSuccessfulStreamUnary(self):233 self._servicer.set_details(_DETAILS)234 unused_response, call = self._stream_unary.with_call(235 iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),236 metadata=_CLIENT_METADATA)237 self.assertTrue(238 test_common.metadata_transmitted(239 _CLIENT_METADATA, self._servicer.received_client_metadata()))240 self.assertTrue(241 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,242 call.initial_metadata()))243 self.assertTrue(244 test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,245 call.trailing_metadata()))246 self.assertIs(grpc.StatusCode.OK, call.code())247 self.assertEqual(_DETAILS, call.details())248 def testSuccessfulStreamStream(self):249 self._servicer.set_details(_DETAILS)250 response_iterator_call = self._stream_stream(251 iter([object()] * test_constants.STREAM_LENGTH),252 metadata=_CLIENT_METADATA)253 received_initial_metadata = response_iterator_call.initial_metadata()254 list(response_iterator_call)255 self.assertTrue(256 test_common.metadata_transmitted(257 _CLIENT_METADATA, self._servicer.received_client_metadata()))258 self.assertTrue(259 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,260 received_initial_metadata))261 self.assertTrue(262 test_common.metadata_transmitted(263 _SERVER_TRAILING_METADATA,264 response_iterator_call.trailing_metadata()))265 self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())266 self.assertEqual(_DETAILS, response_iterator_call.details())267 def testAbortedUnaryUnary(self):268 test_cases = zip(_ABORT_CODES, 
_EXPECTED_CLIENT_CODES,269 _EXPECTED_DETAILS)270 for abort_code, expected_code, expected_details in test_cases:271 self._servicer.set_code(abort_code)272 self._servicer.set_details(_DETAILS)273 self._servicer.set_abort_call()274 with self.assertRaises(grpc.RpcError) as exception_context:275 self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)276 self.assertTrue(277 test_common.metadata_transmitted(278 _CLIENT_METADATA,279 self._servicer.received_client_metadata()))280 self.assertTrue(281 test_common.metadata_transmitted(282 _SERVER_INITIAL_METADATA,283 exception_context.exception.initial_metadata()))284 self.assertTrue(285 test_common.metadata_transmitted(286 _SERVER_TRAILING_METADATA,287 exception_context.exception.trailing_metadata()))288 self.assertIs(expected_code, exception_context.exception.code())289 self.assertEqual(expected_details,290 exception_context.exception.details())291 def testAbortedUnaryStream(self):292 test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,293 _EXPECTED_DETAILS)294 for abort_code, expected_code, expected_details in test_cases:295 self._servicer.set_code(abort_code)296 self._servicer.set_details(_DETAILS)297 self._servicer.set_abort_call()298 response_iterator_call = self._unary_stream(299 _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)300 received_initial_metadata = \301 response_iterator_call.initial_metadata()302 with self.assertRaises(grpc.RpcError):303 self.assertEqual(len(list(response_iterator_call)), 0)304 self.assertTrue(305 test_common.metadata_transmitted(306 _CLIENT_METADATA,307 self._servicer.received_client_metadata()))308 self.assertTrue(309 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,310 received_initial_metadata))311 self.assertTrue(312 test_common.metadata_transmitted(313 _SERVER_TRAILING_METADATA,314 response_iterator_call.trailing_metadata()))315 self.assertIs(expected_code, response_iterator_call.code())316 self.assertEqual(expected_details, response_iterator_call.details())317 def 
testAbortedStreamUnary(self):318 test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,319 _EXPECTED_DETAILS)320 for abort_code, expected_code, expected_details in test_cases:321 self._servicer.set_code(abort_code)322 self._servicer.set_details(_DETAILS)323 self._servicer.set_abort_call()324 with self.assertRaises(grpc.RpcError) as exception_context:325 self._stream_unary.with_call(326 iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),327 metadata=_CLIENT_METADATA)328 self.assertTrue(329 test_common.metadata_transmitted(330 _CLIENT_METADATA,331 self._servicer.received_client_metadata()))332 self.assertTrue(333 test_common.metadata_transmitted(334 _SERVER_INITIAL_METADATA,335 exception_context.exception.initial_metadata()))336 self.assertTrue(337 test_common.metadata_transmitted(338 _SERVER_TRAILING_METADATA,339 exception_context.exception.trailing_metadata()))340 self.assertIs(expected_code, exception_context.exception.code())341 self.assertEqual(expected_details,342 exception_context.exception.details())343 def testAbortedStreamStream(self):344 test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,345 _EXPECTED_DETAILS)346 for abort_code, expected_code, expected_details in test_cases:347 self._servicer.set_code(abort_code)348 self._servicer.set_details(_DETAILS)349 self._servicer.set_abort_call()350 response_iterator_call = self._stream_stream(351 iter([object()] * test_constants.STREAM_LENGTH),352 metadata=_CLIENT_METADATA)353 received_initial_metadata = \354 response_iterator_call.initial_metadata()355 with self.assertRaises(grpc.RpcError):356 self.assertEqual(len(list(response_iterator_call)), 0)357 self.assertTrue(358 test_common.metadata_transmitted(359 _CLIENT_METADATA,360 self._servicer.received_client_metadata()))361 self.assertTrue(362 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,363 received_initial_metadata))364 self.assertTrue(365 test_common.metadata_transmitted(366 _SERVER_TRAILING_METADATA,367 
response_iterator_call.trailing_metadata()))368 self.assertIs(expected_code, response_iterator_call.code())369 self.assertEqual(expected_details, response_iterator_call.details())370 def testCustomCodeUnaryUnary(self):371 self._servicer.set_code(_NON_OK_CODE)372 self._servicer.set_details(_DETAILS)373 with self.assertRaises(grpc.RpcError) as exception_context:374 self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)375 self.assertTrue(376 test_common.metadata_transmitted(377 _CLIENT_METADATA, self._servicer.received_client_metadata()))378 self.assertTrue(379 test_common.metadata_transmitted(380 _SERVER_INITIAL_METADATA,381 exception_context.exception.initial_metadata()))382 self.assertTrue(383 test_common.metadata_transmitted(384 _SERVER_TRAILING_METADATA,385 exception_context.exception.trailing_metadata()))386 self.assertIs(_NON_OK_CODE, exception_context.exception.code())387 self.assertEqual(_DETAILS, exception_context.exception.details())388 def testCustomCodeUnaryStream(self):389 self._servicer.set_code(_NON_OK_CODE)390 self._servicer.set_details(_DETAILS)391 response_iterator_call = self._unary_stream(392 _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)393 received_initial_metadata = response_iterator_call.initial_metadata()394 with self.assertRaises(grpc.RpcError):395 list(response_iterator_call)396 self.assertTrue(397 test_common.metadata_transmitted(398 _CLIENT_METADATA, self._servicer.received_client_metadata()))399 self.assertTrue(400 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,401 received_initial_metadata))402 self.assertTrue(403 test_common.metadata_transmitted(404 _SERVER_TRAILING_METADATA,405 response_iterator_call.trailing_metadata()))406 self.assertIs(_NON_OK_CODE, response_iterator_call.code())407 self.assertEqual(_DETAILS, response_iterator_call.details())408 def testCustomCodeStreamUnary(self):409 self._servicer.set_code(_NON_OK_CODE)410 self._servicer.set_details(_DETAILS)411 with self.assertRaises(grpc.RpcError) as 
exception_context:412 self._stream_unary.with_call(413 iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),414 metadata=_CLIENT_METADATA)415 self.assertTrue(416 test_common.metadata_transmitted(417 _CLIENT_METADATA, self._servicer.received_client_metadata()))418 self.assertTrue(419 test_common.metadata_transmitted(420 _SERVER_INITIAL_METADATA,421 exception_context.exception.initial_metadata()))422 self.assertTrue(423 test_common.metadata_transmitted(424 _SERVER_TRAILING_METADATA,425 exception_context.exception.trailing_metadata()))426 self.assertIs(_NON_OK_CODE, exception_context.exception.code())427 self.assertEqual(_DETAILS, exception_context.exception.details())428 def testCustomCodeStreamStream(self):429 self._servicer.set_code(_NON_OK_CODE)430 self._servicer.set_details(_DETAILS)431 response_iterator_call = self._stream_stream(432 iter([object()] * test_constants.STREAM_LENGTH),433 metadata=_CLIENT_METADATA)434 received_initial_metadata = response_iterator_call.initial_metadata()435 with self.assertRaises(grpc.RpcError) as exception_context:436 list(response_iterator_call)437 self.assertTrue(438 test_common.metadata_transmitted(439 _CLIENT_METADATA, self._servicer.received_client_metadata()))440 self.assertTrue(441 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,442 received_initial_metadata))443 self.assertTrue(444 test_common.metadata_transmitted(445 _SERVER_TRAILING_METADATA,446 exception_context.exception.trailing_metadata()))447 self.assertIs(_NON_OK_CODE, exception_context.exception.code())448 self.assertEqual(_DETAILS, exception_context.exception.details())449 def testCustomCodeExceptionUnaryUnary(self):450 self._servicer.set_code(_NON_OK_CODE)451 self._servicer.set_details(_DETAILS)452 self._servicer.set_exception()453 with self.assertRaises(grpc.RpcError) as exception_context:454 self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)455 self.assertTrue(456 test_common.metadata_transmitted(457 _CLIENT_METADATA, 
self._servicer.received_client_metadata()))458 self.assertTrue(459 test_common.metadata_transmitted(460 _SERVER_INITIAL_METADATA,461 exception_context.exception.initial_metadata()))462 self.assertTrue(463 test_common.metadata_transmitted(464 _SERVER_TRAILING_METADATA,465 exception_context.exception.trailing_metadata()))466 self.assertIs(_NON_OK_CODE, exception_context.exception.code())467 self.assertEqual(_DETAILS, exception_context.exception.details())468 def testCustomCodeExceptionUnaryStream(self):469 self._servicer.set_code(_NON_OK_CODE)470 self._servicer.set_details(_DETAILS)471 self._servicer.set_exception()472 response_iterator_call = self._unary_stream(473 _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)474 received_initial_metadata = response_iterator_call.initial_metadata()475 with self.assertRaises(grpc.RpcError):476 list(response_iterator_call)477 self.assertTrue(478 test_common.metadata_transmitted(479 _CLIENT_METADATA, self._servicer.received_client_metadata()))480 self.assertTrue(481 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,482 received_initial_metadata))483 self.assertTrue(484 test_common.metadata_transmitted(485 _SERVER_TRAILING_METADATA,486 response_iterator_call.trailing_metadata()))487 self.assertIs(_NON_OK_CODE, response_iterator_call.code())488 self.assertEqual(_DETAILS, response_iterator_call.details())489 def testCustomCodeExceptionStreamUnary(self):490 self._servicer.set_code(_NON_OK_CODE)491 self._servicer.set_details(_DETAILS)492 self._servicer.set_exception()493 with self.assertRaises(grpc.RpcError) as exception_context:494 self._stream_unary.with_call(495 iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),496 metadata=_CLIENT_METADATA)497 self.assertTrue(498 test_common.metadata_transmitted(499 _CLIENT_METADATA, self._servicer.received_client_metadata()))500 self.assertTrue(501 test_common.metadata_transmitted(502 _SERVER_INITIAL_METADATA,503 exception_context.exception.initial_metadata()))504 self.assertTrue(505 
test_common.metadata_transmitted(506 _SERVER_TRAILING_METADATA,507 exception_context.exception.trailing_metadata()))508 self.assertIs(_NON_OK_CODE, exception_context.exception.code())509 self.assertEqual(_DETAILS, exception_context.exception.details())510 def testCustomCodeExceptionStreamStream(self):511 self._servicer.set_code(_NON_OK_CODE)512 self._servicer.set_details(_DETAILS)513 self._servicer.set_exception()514 response_iterator_call = self._stream_stream(515 iter([object()] * test_constants.STREAM_LENGTH),516 metadata=_CLIENT_METADATA)517 received_initial_metadata = response_iterator_call.initial_metadata()518 with self.assertRaises(grpc.RpcError):519 list(response_iterator_call)520 self.assertTrue(521 test_common.metadata_transmitted(522 _CLIENT_METADATA, self._servicer.received_client_metadata()))523 self.assertTrue(524 test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,525 received_initial_metadata))526 self.assertTrue(527 test_common.metadata_transmitted(528 _SERVER_TRAILING_METADATA,529 response_iterator_call.trailing_metadata()))530 self.assertIs(_NON_OK_CODE, response_iterator_call.code())531 self.assertEqual(_DETAILS, response_iterator_call.details())532 def testCustomCodeReturnNoneUnaryUnary(self):533 self._servicer.set_code(_NON_OK_CODE)534 self._servicer.set_details(_DETAILS)535 self._servicer.set_return_none()536 with self.assertRaises(grpc.RpcError) as exception_context:537 self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)538 self.assertTrue(539 test_common.metadata_transmitted(540 _CLIENT_METADATA, self._servicer.received_client_metadata()))541 self.assertTrue(542 test_common.metadata_transmitted(543 _SERVER_INITIAL_METADATA,544 exception_context.exception.initial_metadata()))545 self.assertTrue(546 test_common.metadata_transmitted(547 _SERVER_TRAILING_METADATA,548 exception_context.exception.trailing_metadata()))549 self.assertIs(_NON_OK_CODE, exception_context.exception.code())550 self.assertEqual(_DETAILS, 
exception_context.exception.details())551 def testCustomCodeReturnNoneStreamUnary(self):552 self._servicer.set_code(_NON_OK_CODE)553 self._servicer.set_details(_DETAILS)554 self._servicer.set_return_none()555 with self.assertRaises(grpc.RpcError) as exception_context:556 self._stream_unary.with_call(557 iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),558 metadata=_CLIENT_METADATA)559 self.assertTrue(560 test_common.metadata_transmitted(561 _CLIENT_METADATA, self._servicer.received_client_metadata()))562 self.assertTrue(563 test_common.metadata_transmitted(564 _SERVER_INITIAL_METADATA,565 exception_context.exception.initial_metadata()))566 self.assertTrue(567 test_common.metadata_transmitted(568 _SERVER_TRAILING_METADATA,569 exception_context.exception.trailing_metadata()))570 self.assertIs(_NON_OK_CODE, exception_context.exception.code())571 self.assertEqual(_DETAILS, exception_context.exception.details())572if __name__ == '__main__':...
_invalid_metadata_test.py
Source: _invalid_metadata_test.py
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests that gRPC Python rejects RPCs carrying invalid metadata.

Every test attaches metadata whose key, ``'InVaLiD'``, is not lower-case,
and checks how the rejection surfaces:

* blocking calls (plain call and ``with_call``) raise ``ValueError``;
* ``future`` calls and response iterators fail with ``grpc.RpcError`` whose
  code is ``grpc.StatusCode.INTERNAL``.
"""

import unittest

import grpc

from tests.unit.framework.common import test_constants

# Toy codecs: requests are doubled on the wire and responses tripled; each
# deserializer slices the payload back down to its original length.
_SERIALIZE_REQUEST = lambda bytestring: bytestring * 2
_DESERIALIZE_REQUEST = lambda bytestring: bytestring[len(bytestring) // 2:]
_SERIALIZE_RESPONSE = lambda bytestring: bytestring * 3
_DESERIALIZE_RESPONSE = lambda bytestring: bytestring[:len(bytestring) // 3]

_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
_STREAM_UNARY = '/test/StreamUnary'
_STREAM_STREAM = '/test/StreamStream'


def _unary_unary_multi_callable(channel):
    """Return a unary-unary multi-callable with no serializers."""
    return channel.unary_unary(_UNARY_UNARY)


def _unary_stream_multi_callable(channel):
    """Return a unary-stream multi-callable using the toy codecs."""
    return channel.unary_stream(
        _UNARY_STREAM,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE)


def _stream_unary_multi_callable(channel):
    """Return a stream-unary multi-callable using the toy codecs."""
    return channel.stream_unary(
        _STREAM_UNARY,
        request_serializer=_SERIALIZE_REQUEST,
        response_deserializer=_DESERIALIZE_RESPONSE)


def _stream_stream_multi_callable(channel):
    """Return a stream-stream multi-callable with no serializers."""
    return channel.stream_stream(_STREAM_STREAM)


class InvalidMetadataTest(unittest.TestCase):

    def setUp(self):
        # No server is started in this module. The tests below expect the
        # invalid metadata to be rejected rather than a connection error
        # (assumption inferred from their expectations; confirm upstream).
        self._channel = grpc.insecure_channel('localhost:8080')
        self._unary_unary = _unary_unary_multi_callable(self._channel)
        self._unary_stream = _unary_stream_multi_callable(self._channel)
        self._stream_unary = _stream_unary_multi_callable(self._channel)
        self._stream_stream = _stream_stream_multi_callable(self._channel)

    def tearDown(self):
        # Close the per-test channel opened in setUp so it is not leaked.
        self._channel.close()

    @staticmethod
    def _expected_details(metadata):
        """Return the error text expected for the invalid ``metadata``.

        Every caller passes a one-element tuple, so ``%`` formats that single
        ``(key, value)`` pair rather than the outer tuple. The text must stay
        aligned with the message gRPC itself produces.
        """
        return "metadata was invalid: %s" % metadata

    def _assert_internal_error(self, call, error, expected_details):
        """Assert both the raised ``error`` and ``call`` report INTERNAL.

        Args:
            call: The future or response iterator that was used.
            error: The ``grpc.RpcError`` raised when consuming ``call``.
            expected_details: The details string both must carry.
        """
        self.assertEqual(error.details(), expected_details)
        self.assertEqual(error.code(), grpc.StatusCode.INTERNAL)
        self.assertEqual(call.details(), expected_details)
        self.assertEqual(call.code(), grpc.StatusCode.INTERNAL)

    def testUnaryRequestBlockingUnaryResponse(self):
        request = b'\x07\x08'
        metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),)
        expected_error_details = self._expected_details(metadata)
        with self.assertRaises(ValueError) as exception_context:
            self._unary_unary(request, metadata=metadata)
        self.assertIn(expected_error_details, str(exception_context.exception))

    def testUnaryRequestBlockingUnaryResponseWithCall(self):
        request = b'\x07\x08'
        metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponseWithCall'),)
        expected_error_details = self._expected_details(metadata)
        with self.assertRaises(ValueError) as exception_context:
            self._unary_unary.with_call(request, metadata=metadata)
        self.assertIn(expected_error_details, str(exception_context.exception))

    def testUnaryRequestFutureUnaryResponse(self):
        request = b'\x07\x08'
        metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),)
        expected_error_details = self._expected_details(metadata)
        response_future = self._unary_unary.future(request, metadata=metadata)
        with self.assertRaises(grpc.RpcError) as exception_context:
            response_future.result()
        self._assert_internal_error(response_future,
                                    exception_context.exception,
                                    expected_error_details)

    def testUnaryRequestStreamResponse(self):
        request = b'\x37\x58'
        metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),)
        expected_error_details = self._expected_details(metadata)
        response_iterator = self._unary_stream(request, metadata=metadata)
        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self._assert_internal_error(response_iterator,
                                    exception_context.exception,
                                    expected_error_details)

    def testStreamRequestBlockingUnaryResponse(self):
        request_iterator = (
            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
        metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),)
        expected_error_details = self._expected_details(metadata)
        with self.assertRaises(ValueError) as exception_context:
            self._stream_unary(request_iterator, metadata=metadata)
        self.assertIn(expected_error_details, str(exception_context.exception))

    def testStreamRequestBlockingUnaryResponseWithCall(self):
        request_iterator = (
            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
        metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),)
        expected_error_details = self._expected_details(metadata)
        # Use the shared multi-callable like every other test, instead of
        # building a redundant one on the same channel.
        with self.assertRaises(ValueError) as exception_context:
            self._stream_unary.with_call(request_iterator, metadata=metadata)
        self.assertIn(expected_error_details, str(exception_context.exception))

    def testStreamRequestFutureUnaryResponse(self):
        request_iterator = (
            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
        metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),)
        expected_error_details = self._expected_details(metadata)
        response_future = self._stream_unary.future(
            request_iterator, metadata=metadata)
        with self.assertRaises(grpc.RpcError) as exception_context:
            response_future.result()
        self._assert_internal_error(response_future,
                                    exception_context.exception,
                                    expected_error_details)

    def testStreamRequestStreamResponse(self):
        request_iterator = (
            b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
        metadata = (('InVaLiD', 'StreamRequestStreamResponse'),)
        expected_error_details = self._expected_details(metadata)
        response_iterator = self._stream_stream(
            request_iterator, metadata=metadata)
        with self.assertRaises(grpc.RpcError) as exception_context:
            next(response_iterator)
        self._assert_internal_error(response_iterator,
                                    exception_context.exception,
                                    expected_error_details)


if __name__ == '__main__':
    # The guard previously had an Ellipsis body, so running the file as a
    # script executed no tests.
    unittest.main()
Learn to execute automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with test automation frameworks such as Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!