Best Python code snippet using toolium_python
rpc_op_test_base.py
Source: rpc_op_test_base.py
...28I_WARNED_YOU = 'I warned you!'29class RpcOpTestBase(object):30 # pylint: disable=missing-docstring,invalid-name31 """Base class for RpcOp tests."""32 def get_method_name(self, suffix):33 raise NotImplementedError34 def rpc(self, *args, **kwargs):35 return rpc_op.rpc(*args, protocol=self._protocol, **kwargs)36 def try_rpc(self, *args, **kwargs):37 return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)38 def testScalarHostPortRpc(self):39 with self.cached_session() as sess:40 request_tensors = (41 test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())42 response_tensors = self.rpc(43 method=self.get_method_name('Increment'),44 address=self._address,45 request=request_tensors)46 self.assertEqual(response_tensors.shape, ())47 response_values = sess.run(response_tensors)48 response_message = test_example_pb2.TestCase()49 self.assertTrue(response_message.ParseFromString(response_values))50 self.assertAllEqual([2, 3, 4], response_message.values)51 def testScalarHostPortTryRpc(self):52 with self.cached_session() as sess:53 request_tensors = (54 test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())55 response_tensors, status_code, status_message = self.try_rpc(56 method=self.get_method_name('Increment'),57 address=self._address,58 request=request_tensors)59 self.assertEqual(status_code.shape, ())60 self.assertEqual(status_message.shape, ())61 self.assertEqual(response_tensors.shape, ())62 response_values, status_code_values, status_message_values = (63 sess.run((response_tensors, status_code, status_message)))64 response_message = test_example_pb2.TestCase()65 self.assertTrue(response_message.ParseFromString(response_values))66 self.assertAllEqual([2, 3, 4], response_message.values)67 # For the base Rpc op, don't expect to get error status back.68 self.assertEqual(errors.OK, status_code_values)69 self.assertEqual(b'', status_message_values)70 def testEmptyHostPortRpc(self):71 with self.cached_session() as sess:72 request_tensors = []73 
response_tensors = self.rpc(74 method=self.get_method_name('Increment'),75 address=self._address,76 request=request_tensors)77 self.assertAllEqual(response_tensors.shape, [0])78 response_values = sess.run(response_tensors)79 self.assertAllEqual(response_values.shape, [0])80 def testInvalidMethod(self):81 for method in [82 '/InvalidService.Increment',83 self.get_method_name('InvalidMethodName')84 ]:85 with self.cached_session() as sess:86 with self.assertRaisesOpError(self.invalid_method_string):87 sess.run(self.rpc(method=method, address=self._address, request=''))88 _, status_code_value, status_message_value = sess.run(89 self.try_rpc(method=method, address=self._address, request=''))90 self.assertEqual(errors.UNIMPLEMENTED, status_code_value)91 self.assertTrue(92 self.invalid_method_string in status_message_value.decode('ascii'))93 def testInvalidAddress(self):94 # This covers the case of address='' and address='localhost:293874293874'95 address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'96 with self.cached_session() as sess:97 with self.assertRaises(errors.UnavailableError):98 sess.run(99 self.rpc(100 method=self.get_method_name('Increment'),101 address=address,102 request=''))103 _, status_code_value, status_message_value = sess.run(104 self.try_rpc(105 method=self.get_method_name('Increment'),106 address=address,107 request=''))108 self.assertEqual(errors.UNAVAILABLE, status_code_value)109 def testAlwaysFailingMethod(self):110 with self.cached_session() as sess:111 response_tensors = self.rpc(112 method=self.get_method_name('AlwaysFailWithInvalidArgument'),113 address=self._address,114 request='')115 self.assertEqual(response_tensors.shape, ())116 with self.assertRaisesOpError(I_WARNED_YOU):117 sess.run(response_tensors)118 response_tensors, status_code, status_message = self.try_rpc(119 method=self.get_method_name('AlwaysFailWithInvalidArgument'),120 address=self._address,121 request='')122 self.assertEqual(response_tensors.shape, ())123 
self.assertEqual(status_code.shape, ())124 self.assertEqual(status_message.shape, ())125 status_code_value, status_message_value = sess.run((status_code,126 status_message))127 self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)128 self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))129 def testSometimesFailingMethodWithManyRequests(self):130 with self.cached_session() as sess:131 # Fail hard by default.132 response_tensors = self.rpc(133 method=self.get_method_name('SometimesFailWithInvalidArgument'),134 address=self._address,135 request=[''] * 20)136 self.assertEqual(response_tensors.shape, (20,))137 with self.assertRaisesOpError(I_WARNED_YOU):138 sess.run(response_tensors)139 # Don't fail hard, use TryRpc - return the failing status instead.140 response_tensors, status_code, status_message = self.try_rpc(141 method=self.get_method_name('SometimesFailWithInvalidArgument'),142 address=self._address,143 request=[''] * 20)144 self.assertEqual(response_tensors.shape, (20,))145 self.assertEqual(status_code.shape, (20,))146 self.assertEqual(status_message.shape, (20,))147 status_code_values, status_message_values = sess.run((status_code,148 status_message))149 self.assertTrue([150 x in (errors.OK, errors.INVALID_ARGUMENT) for x in status_code_values151 ])152 expected_message_values = np.where(153 status_code_values == errors.INVALID_ARGUMENT,154 I_WARNED_YOU.encode('ascii'), b'')155 for msg, expected in zip(status_message_values, expected_message_values):156 self.assertTrue(expected in msg,157 '"%s" did not contain "%s"' % (msg, expected))158 def testVecHostPortRpc(self):159 with self.cached_session() as sess:160 request_tensors = [161 test_example_pb2.TestCase(162 values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)163 ]164 response_tensors = self.rpc(165 method=self.get_method_name('Increment'),166 address=self._address,167 request=request_tensors)168 self.assertEqual(response_tensors.shape, (20,))169 response_values = 
sess.run(response_tensors)170 self.assertEqual(response_values.shape, (20,))171 for i in range(20):172 response_message = test_example_pb2.TestCase()173 self.assertTrue(response_message.ParseFromString(response_values[i]))174 self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)175 def testVecHostPortManyParallelRpcs(self):176 with self.cached_session() as sess:177 request_tensors = [178 test_example_pb2.TestCase(179 values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)180 ]181 many_response_tensors = [182 self.rpc(183 method=self.get_method_name('Increment'),184 address=self._address,185 request=request_tensors) for _ in range(10)186 ]187 # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests.188 many_response_values = sess.run(many_response_tensors)189 self.assertEqual(10, len(many_response_values))190 for response_values in many_response_values:191 self.assertEqual(response_values.shape, (20,))192 for i in range(20):193 response_message = test_example_pb2.TestCase()194 self.assertTrue(response_message.ParseFromString(response_values[i]))195 self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)196 def testVecHostPortRpcUsingEncodeAndDecodeProto(self):197 with self.cached_session() as sess:198 request_tensors = proto_ops.encode_proto(199 message_type='tensorflow.contrib.rpc.TestCase',200 field_names=['values'],201 sizes=[[3]] * 20,202 values=[203 [[i, i + 1, i + 2] for i in range(20)],204 ])205 response_tensor_strings = self.rpc(206 method=self.get_method_name('Increment'),207 address=self._address,208 request=request_tensors)209 _, (response_shape,) = proto_ops.decode_proto(210 bytes=response_tensor_strings,211 message_type='tensorflow.contrib.rpc.TestCase',212 field_names=['values'],213 output_types=[dtypes.int32])214 response_shape_values = sess.run(response_shape)215 self.assertAllEqual([[i + 1, i + 2, i + 3]216 for i in range(20)], response_shape_values)217 def 
testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):218 with self.cached_session() as sess:219 request_tensors = [''] * 25 # This will launch 25 RPC requests.220 response_tensors = self.rpc(221 method=self.get_method_name('SleepForever'),222 address=self._address,223 request=request_tensors)224 for timeout_ms in [1, 500, 1000]:225 options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)226 with self.assertRaises((errors.UnavailableError,227 errors.DeadlineExceededError)):228 sess.run(response_tensors, options=options)229 def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):230 with self.cached_session() as sess:231 request_tensors = [''] * 25 # This will launch 25 RPC requests.232 response_tensors = self.rpc(233 method=self.get_method_name('SleepForever'),234 address=self._address,235 timeout_in_ms=1000,236 request=request_tensors)237 with self.assertRaises(errors.DeadlineExceededError):238 sess.run(response_tensors)239 def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):240 with self.cached_session() as sess:241 response_tensors, status_code, status_message = self.try_rpc(242 method=self.get_method_name('SometimesSleepForever'),243 timeout_in_ms=1000,244 address=self._address,245 request=[''] * 20)246 self.assertEqual(response_tensors.shape, (20,))247 self.assertEqual(status_code.shape, (20,))248 self.assertEqual(status_message.shape, (20,))249 status_code_values = sess.run(status_code)250 self.assertTrue([251 x in (errors.OK, errors.DEADLINE_EXCEEDED) for x in status_code_values252 ])253 def testTryRpcWithMultipleAddressesSingleRequest(self):254 flatten = lambda x: list(itertools.chain.from_iterable(x))255 with self.cached_session() as sess:256 addresses = flatten([[257 self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'258 ] for _ in range(10)])259 request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()260 response_tensors, status_code, _ = self.try_rpc(261 
method=self.get_method_name('Increment'),262 address=addresses,263 request=request)264 response_tensors_values, status_code_values = sess.run((response_tensors,265 status_code))266 self.assertAllEqual(267 flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),268 status_code_values)269 for i in range(10):270 self.assertTrue(response_tensors_values[2 * i])271 self.assertFalse(response_tensors_values[2 * i + 1])272 def testTryRpcWithMultipleMethodsSingleRequest(self):273 flatten = lambda x: list(itertools.chain.from_iterable(x))274 with self.cached_session() as sess:275 methods = flatten(276 [[self.get_method_name('Increment'), 'InvalidMethodName']277 for _ in range(10)])278 request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()279 response_tensors, status_code, _ = self.try_rpc(280 method=methods, address=self._address, request=request)281 response_tensors_values, status_code_values = sess.run((response_tensors,282 status_code))283 self.assertAllEqual(284 flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),285 status_code_values)286 for i in range(10):287 self.assertTrue(response_tensors_values[2 * i])288 self.assertFalse(response_tensors_values[2 * i + 1])289 def testTryRpcWithMultipleAddressesAndRequests(self):290 flatten = lambda x: list(itertools.chain.from_iterable(x))291 with self.cached_session() as sess:292 addresses = flatten([[293 self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'294 ] for _ in range(10)])295 requests = [296 test_example_pb2.TestCase(297 values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)298 ]299 response_tensors, status_code, _ = self.try_rpc(300 method=self.get_method_name('Increment'),301 address=addresses,302 request=requests)303 response_tensors_values, status_code_values = sess.run((response_tensors,304 status_code))305 self.assertAllEqual(306 flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),307 status_code_values)308 for i in range(20):309 if i % 2 == 1:310 
self.assertFalse(response_tensors_values[i])311 else:312 response_message = test_example_pb2.TestCase()313 self.assertTrue(314 response_message.ParseFromString(response_tensors_values[i]))...
Learn automation testing from scratch with the LambdaTest Learning Hub: from setting up the prerequisites and running your first automation test, to following best practices and diving deeper into advanced test scenarios. The LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!