Best Python code snippet using nose2
ops_test.py
Source: ops_test.py
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15from __future__ import absolute_import16from __future__ import division17from __future__ import print_function18import numpy as np19from six.moves import range # pylint: disable=redefined-builtin20from tensorflow.contrib.labeled_tensor.python.ops import core21from tensorflow.contrib.labeled_tensor.python.ops import ops22from tensorflow.contrib.labeled_tensor.python.ops import test_util23from tensorflow.python.framework import constant_op24from tensorflow.python.framework import dtypes25from tensorflow.python.framework import errors_impl26from tensorflow.python.ops import array_ops27from tensorflow.python.ops import math_ops28from tensorflow.python.ops import string_ops29from tensorflow.python.platform import test as test_lib30class Base(test_util.Base):31 def setUp(self):32 super(Base, self).setUp()33 self.x_size = 734 self.channel_size = 335 self.z_size = 436 self.probs_size = 1137 tensor = math_ops.range(0, self.x_size * self.channel_size * self.z_size *38 self.probs_size)39 tensor = array_ops.reshape(40 tensor, [self.x_size, self.channel_size, self.z_size, self.probs_size])41 a0 = ('x', range(self.x_size))42 a1 = ('channel', ['red', 'green', 'blue'])43 a2 = 'z'44 a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))45 self.tensor = tensor46 self.a0 = a047 self.a1 = a148 self.a2 = 
a249 self.a2_resolved = ('z', self.z_size)50 self.a3 = a351 self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])52 self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0})53 self.x_probs_lt = ops.select(self.x_probs_lt, {'channel': 'red'})54 self.channel_probs_lt = core.slice_function(self.original_lt,55 {'x': 3,56 'z': 0})57class SelectTest(Base):58 def test_name(self):59 select_lt = ops.select(self.original_lt, {'channel': 'green'})60 self.assertIn('lt_select', select_lt.name)61 def test_scalar(self):62 select_lt = ops.select(self.original_lt, {'channel': 'green'})63 golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :],64 [self.a0, self.a2, self.a3])65 self.assertLabeledTensorsEqual(select_lt, golden_lt)66 def test_slice(self):67 select_lt = ops.select(self.original_lt, {'channel': slice('red', 'green')})68 a1_sliced = ('channel', ['red', 'green'])69 golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],70 [self.a0, a1_sliced, self.a2, self.a3])71 self.assertLabeledTensorsEqual(select_lt, golden_lt)72 def test_slices(self):73 select_lt = ops.select(self.original_lt,74 {'x': slice(1, 4),75 'channel': slice('green', None)})76 a0_sliced = ('x', range(1, 5))77 a1_sliced = ('channel', ['green', 'blue'])78 golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],79 [a0_sliced, a1_sliced, self.a2, self.a3])80 self.assertLabeledTensorsEqual(select_lt, golden_lt)81 def test_list(self):82 select_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})83 a1_sliced = ('channel', ['red', 'green'])84 golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],85 [self.a0, a1_sliced, self.a2, self.a3])86 self.assertLabeledTensorsEqual(select_lt, golden_lt)87 def test_list_one_item(self):88 select_lt = ops.select(self.original_lt, {'channel': ['red']})89 a1_sliced = ('channel', ['red'])90 golden_lt = core.LabeledTensor(self.tensor[:, :1, :, :],91 [self.a0, a1_sliced, self.a2, self.a3])92 self.assertLabeledTensorsEqual(select_lt, golden_lt)93 def 
test_list_zero_items(self):94 select_lt = ops.select(self.original_lt, {'channel': []})95 golden_lt = core.LabeledTensor(self.tensor[:, :0, :, :],96 [self.a0, 'channel', self.a2, self.a3])97 self.assertLabeledTensorsEqual(select_lt, golden_lt)98 def test_scalars(self):99 select_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'})100 golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :], [self.a2, self.a3])101 self.assertLabeledTensorsEqual(select_lt, golden_lt)102 def test_tuple(self):103 original_lt = core.LabeledTensor(constant_op.constant([5, 6]),104 [('x', [(1, 2), (3, 4)])])105 select_lt = ops.select(original_lt, {'x': (1, 2)})106 golden_lt = core.LabeledTensor(constant_op.constant(5), [])107 self.assertLabeledTensorsEqual(select_lt, golden_lt)108 def test_invalid_input(self):109 with self.assertRaises(ValueError):110 ops.select(self.original_lt, {'foo': 1})111 with self.assertRaises(ValueError):112 ops.select(self.original_lt, {'z': 1})113 with self.assertRaises(KeyError):114 ops.select(self.original_lt, {'channel': 'purple'})115 with self.assertRaises(KeyError):116 ops.select(self.original_lt, {'channel': ['red', 'purple']})117 with self.assertRaises(NotImplementedError):118 ops.select(self.original_lt, {'channel': ['red'], 'x': [1]})119 with self.assertRaises(NotImplementedError):120 ops.select(self.original_lt, {'channel': ['red'], 'x': 1})121 with self.assertRaises(NotImplementedError):122 ops.select(self.original_lt, {'channel': slice('red', 'green', 2)})123class ConcatTest(Base):124 def setUp(self):125 super(ConcatTest, self).setUp()126 self.red_lt = ops.select(self.original_lt, {'channel': ['red']})127 self.green_lt = ops.select(self.original_lt, {'channel': ['green']})128 self.blue_lt = ops.select(self.original_lt, {'channel': ['blue']})129 def test_name(self):130 concat_lt = ops.concat([self.red_lt, self.blue_lt], 'channel')131 self.assertIn('lt_concat', concat_lt.name)132 def test(self):133 concat_lt = ops.concat([self.red_lt, 
self.green_lt], 'channel')134 golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})135 self.assertLabeledTensorsEqual(concat_lt, golden_lt)136 def test_transposed(self):137 green_transposed = core.transpose(self.green_lt,138 ['probs', 'channel', 'z', 'x'])139 with self.assertRaises(ValueError):140 ops.concat([self.red_lt, green_transposed], 'channel')141 def test_invalid_input(self):142 with self.assertRaises(ValueError):143 ops.concat([], 'channel')144 with self.assertRaises(ValueError):145 ops.concat([self.red_lt, self.red_lt], 'channel')146 with self.assertRaises(ValueError):147 ops.concat([self.red_lt, self.red_lt], 'foo')148class PackTest(Base):149 def test_name(self):150 pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')151 self.assertIn('lt_pack', pack_lt.name)152 def test(self):153 pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')154 golden_lt = core.LabeledTensor(155 array_ops.stack([self.original_lt.tensor, self.original_lt.tensor]),156 ['batch', self.a0, self.a1, self.a2, self.a3])157 self.assertLabeledTensorsEqual(pack_lt, golden_lt)158 def test_axis(self):159 pack_lt = ops.pack(160 [self.original_lt, self.original_lt], new_axis='batch', axis_position=4)161 golden_lt = core.LabeledTensor(162 array_ops.stack(163 [self.original_lt.tensor, self.original_lt.tensor], axis=4),164 [self.a0, self.a1, self.a2, self.a3, 'batch'])165 self.assertLabeledTensorsEqual(pack_lt, golden_lt)166 def test_invalid_input(self):167 with self.assertRaises(ValueError):168 ops.pack([self.original_lt, self.original_lt], 'channel')169class UnpackTest(Base):170 def test_name(self):171 unpack_lts = ops.unpack(self.original_lt)172 for t in unpack_lts:173 self.assertIn('lt_unpack', t.name)174 def test(self):175 unpack_lt = ops.unpack(self.original_lt)[0]176 golden_lt = core.LabeledTensor(177 array_ops.unstack(self.original_lt.tensor)[0],178 [self.a1, self.a2, self.a3])179 self.assertLabeledTensorsEqual(unpack_lt, golden_lt)180 def 
test_axis(self):181 unpack_lt = ops.unpack(self.original_lt, axis_name='z')[0]182 golden_lt = core.LabeledTensor(183 array_ops.unstack(184 self.original_lt.tensor, axis=2)[0], [self.a0, self.a1, self.a3])185 self.assertLabeledTensorsEqual(unpack_lt, golden_lt)186 def test_invalid_input(self):187 with self.assertRaises(ValueError):188 ops.unpack(self.original_lt, axis_name='not_found')189class ReshapeTest(Base):190 def test_name(self):191 reshape_lt = ops.reshape(self.original_lt, ['channel'], ['foo'])192 self.assertIn('lt_reshape', reshape_lt.name)193 def test_identity(self):194 reshape_lt = ops.reshape(self.original_lt,195 self.original_lt.axes.keys(),196 self.original_lt.axes.values())197 self.assertLabeledTensorsEqual(reshape_lt, self.original_lt)198 def test_known_size(self):199 new_dim_size = self.channel_size * self.z_size * self.probs_size200 reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],201 [('new_dim', new_dim_size)])202 golden_lt = core.LabeledTensor(203 array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]),204 [self.original_lt.axes['x'], 'new_dim'])205 self.assertLabeledTensorsEqual(reshape_lt, golden_lt)206 def test_unknown_size(self):207 reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],208 ['new_dim'])209 golden_lt = core.LabeledTensor(210 array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]),211 [self.original_lt.axes['x'], 'new_dim'])212 self.assertLabeledTensorsEqual(reshape_lt, golden_lt)213 def test_unknown_dimension(self):214 orig_lt = core.LabeledTensor(215 array_ops.placeholder(dtypes.float32, [None]), ['x'])216 reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])217 self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)]))218 with self.test_session() as sess:219 result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]})220 np.testing.assert_array_equal(result, [[1], [2]])221 def test_with_labels(self):222 new_dim_size = self.channel_size * self.z_size * 
self.probs_size223 reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],224 [('new_dim', range(new_dim_size))])225 golden_lt = core.LabeledTensor(226 array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]),227 [self.original_lt.axes['x'], ('new_dim', range(new_dim_size))])228 self.assertLabeledTensorsEqual(reshape_lt, golden_lt)229 def test_invalid_input(self):230 with self.assertRaisesRegexp(ValueError, 'not contained in the set'):231 ops.reshape(self.original_lt, ['foo'], ['bar'])232 with self.assertRaisesRegexp(core.AxisOrderError,233 'not a slice of axis names'):234 ops.reshape(self.original_lt, ['probs', 'z'], ['bar'])235 with self.assertRaisesRegexp(ValueError, 'at most one axis in new_axes'):236 ops.reshape(self.original_lt, ['probs'], ['foo', 'bar'])237class RenameAxisTest(Base):238 def test_name(self):239 rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')240 self.assertIn('lt_rename_axis', rename_axis_lt.name)241 def test_identity(self):242 rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'channel')243 self.assertLabeledTensorsEqual(rename_axis_lt, self.original_lt)244 def test_new_name(self):245 rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')246 expected_axes = [(name if name != 'channel' else 'foo', axis.value)247 for name, axis in self.original_lt.axes.items()]248 expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes)249 self.assertLabeledTensorsEqual(rename_axis_lt, expected_lt)250 def test_invalid_input(self):251 with self.assertRaisesRegexp(ValueError, 'not contained in the set'):252 ops.rename_axis(self.original_lt, 'foo', 'bar')253class BatchTest(Base):254 def setUp(self):255 super(BatchTest, self).setUp()256 tensors = []257 for i in range(10):258 offset_lt = core.LabeledTensor(constant_op.constant(i), [])259 tensors.append(core.add(self.original_lt, offset_lt))260 self.pack_lt = ops.pack(tensors, 'batch')261 def test_name(self):262 batch_ops = ops.batch(263 
[self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True)264 for bo in batch_ops:265 self.assertIn('lt_batch', bo.name)266 def test_enqueue_many(self):267 [batch_2_op] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True)268 self.assertEqual(len(batch_2_op.axes['batch']), 2)269 [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)270 self.assertLabeledTensorsEqual(self.pack_lt, batch_10_op)271 def test_no_enqueue_many(self):272 [batch_2_op] = ops.batch([self.original_lt], batch_size=2)273 self.assertEqual(len(batch_2_op.axes['batch']), 2)274 [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)275 self.assertLabeledTensorsEqual(276 ops.pack(10 * [self.original_lt], 'batch'), batch_10_op)277 def test_invalid_input(self):278 with self.assertRaises(ValueError):279 ops.batch([self.original_lt], 3, enqueue_many=True)280 def test_allow_smaller_final_batch(self):281 [batch_2_op] = ops.batch(282 [self.original_lt], batch_size=2, allow_smaller_final_batch=True)283 self.assertEqual(batch_2_op.axes['batch'].size, None)284class ShuffleBatchTest(Base):285 def setUp(self):286 super(ShuffleBatchTest, self).setUp()287 tensors = []288 for i in range(10):289 offset_lt = core.LabeledTensor(constant_op.constant(i), [])290 tensors.append(core.add(self.original_lt, offset_lt))291 self.pack_lt = ops.pack(tensors, 'batch')292 def test_name(self):293 batch_lts = ops.shuffle_batch(294 [self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True)295 for blt in batch_lts:296 self.assertIn('lt_shuffle_batch', blt.name)297 def test_enqueue_many(self):298 [batch_2_lt] = ops.shuffle_batch(299 [self.pack_lt],300 batch_size=2,301 enqueue_many=True,302 min_after_dequeue=8,303 seed=0)304 self.assertEqual(len(batch_2_lt.axes['batch']), 2)305 [batch_10_lt] = ops.batch([batch_2_lt], batch_size=10, enqueue_many=True)306 self.assertEqual(batch_10_lt.axes, self.pack_lt.axes)307 [batch_10, pack] = self.eval([batch_10_lt.tensor, self.pack_lt.tensor])308 
self.assertFalse((batch_10 == pack).all())309 def test_allow_smaller_final_batch(self):310 [batch_2_op] = ops.shuffle_batch(311 [self.original_lt], batch_size=2, allow_smaller_final_batch=True)312 self.assertEqual(batch_2_op.axes['batch'].size, None)313class RandomCropTest(Base):314 def test_name(self):315 crop_lt = ops.random_crop(self.original_lt, {'probs': 3})316 self.assertIn('lt_random_crop', crop_lt.name)317 def test_single(self):318 crop_lt = ops.random_crop(self.original_lt, {'probs': 3})319 self.assertEqual(320 core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 3)]),321 crop_lt.axes)322 def test_double(self):323 crop_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2})324 self.assertEqual(325 core.Axes([self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)]),326 crop_lt.axes)327 def test_size1(self):328 crop_lt = ops.random_crop(self.original_lt, {'probs': 1})329 self.assertEqual(330 core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 1)]),331 crop_lt.axes)332 def test_different_seeds(self):333 crop_0_lt = ops.random_crop(334 self.original_lt, {'probs': 3,335 'channel': 2}, seed=0)336 crop_1_lt = ops.random_crop(337 self.original_lt, {'probs': 3,338 'channel': 2}, seed=1)339 self.assertEqual(crop_0_lt.axes, crop_1_lt.axes)340 [crop_0, crop_1] = self.eval([crop_0_lt.tensor, crop_1_lt.tensor])341 self.assertFalse((crop_0 == crop_1).all())342 def test_identical_seeds(self):343 crop_0_lt = ops.random_crop(344 self.original_lt, {'probs': 3,345 'channel': 2}, seed=0)346 crop_1_lt = ops.random_crop(347 self.original_lt, {'probs': 3,348 'channel': 2}, seed=0)349 self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)350 def test_crop_idempotent(self):351 crop_0_lt = ops.random_crop(352 self.original_lt, {'probs': 3,353 'channel': 2}, seed=0)354 crop_1_lt = ops.random_crop(crop_0_lt, {'probs': 3, 'channel': 2}, seed=1)355 self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)356 def test_invalid_input(self):357 with 
self.assertRaises(ValueError):358 ops.random_crop(self.original_lt, {'foobar': 2})359class MapFnTest(Base):360 def test_name(self):361 map_lt = ops.map_fn(core.identity, self.original_lt)362 self.assertIn('lt_map_fn', map_lt.name)363 def test_identity(self):364 map_lt = ops.map_fn(core.identity, self.original_lt)365 self.assertLabeledTensorsEqual(map_lt, self.original_lt)366 def test_callable_object(self):367 class Identity(object):368 def __call__(self, other):369 return other370 map_lt = ops.map_fn(Identity(), self.original_lt)371 self.assertLabeledTensorsEqual(map_lt, self.original_lt)372 def test_slice(self):373 map_lt = ops.map_fn(lambda t: core.slice_function(t, {'channel': 1}),374 self.original_lt)375 slice_lt = core.slice_function(self.original_lt, {'channel': 1})376 self.assertLabeledTensorsEqual(map_lt, slice_lt)377 def test_string(self):378 def fn(entry_lt):379 op = string_ops.string_join([entry_lt, 'world'])380 return core.LabeledTensor(op, [])381 tensor_lt = ops.constant(['hi', 'bye'], axes=['batch'])382 map_lt = ops.map_fn(fn, tensor_lt)383 golden_lt = ops.constant(['hiworld', 'byeworld'], axes=['batch'])384 self.assertLabeledTensorsEqual(map_lt, golden_lt)385class FoldlTest(Base):386 def test_name(self):387 foldl_lt = ops.foldl(core.add, self.original_lt,388 core.slice_function(self.original_lt, {'x': 0}))389 self.assertIn('lt_foldl', foldl_lt.name)390 def test_sum(self):391 initializer_lt = ops.constant([0, 10], axes=['y'])392 tensor_lt = ops.constant([[1, 2], [3, 4], [5, 6]], axes=['x', 'y'])393 foldl_lt = ops.foldl(core.add, tensor_lt, initializer_lt)394 golden_lt = ops.constant([9, 22], axes=['y'])395 self.assertLabeledTensorsEqual(foldl_lt, golden_lt)396class SqueezeTest(Base):397 def setUp(self):398 super(SqueezeTest, self).setUp()399 self.squeezable_lt = core.slice_function(400 self.original_lt, {'channel': slice(0, 1),401 'probs': slice(0, 1)})402 def test_name(self):403 squeeze_lt = ops.squeeze(self.squeezable_lt)404 
self.assertIn('lt_squeeze', squeeze_lt.name)405 def test_none(self):406 none_lt = ops.squeeze(self.squeezable_lt, None)407 axes_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs'])408 self.assertLabeledTensorsEqual(none_lt, axes_lt)409 def test(self):410 squeeze_lt = ops.squeeze(self.squeezable_lt, ['probs'])411 golden_lt = core.slice_function(self.squeezable_lt, {'probs': 0})412 self.assertLabeledTensorsEqual(squeeze_lt, golden_lt)413 def test_invalid_input(self):414 with self.assertRaises(ValueError):415 ops.squeeze(self.original_lt, ['channel'])416 with self.assertRaises(ValueError):417 ops.squeeze(self.squeezable_lt, ['foo'])418class MatMulTest(Base):419 def test_name(self):420 x_lt = core.LabeledTensor(array_ops.ones((3,)), ['x'])421 matmul_lt = ops.matmul(x_lt, x_lt)422 self.assertIn('lt_matmul', matmul_lt.name)423 def test_vector_vector(self):424 x_lt = core.LabeledTensor(math_ops.range(3), ['x'])425 matmul_lt = ops.matmul(x_lt, x_lt)426 golden_lt = core.convert_to_labeled_tensor(5)427 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)428 def test_matrix_vector(self):429 xy_lt = core.LabeledTensor(430 array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])431 y_lt = core.LabeledTensor(math_ops.range(3), ['y'])432 matmul_lt = ops.matmul(xy_lt, y_lt)433 golden_lt = core.LabeledTensor(434 math_ops.matmul(xy_lt.tensor, array_ops.reshape(y_lt.tensor,435 (-1, 1)))[:, 0], ['x'])436 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)437 matmul_lt = ops.matmul(y_lt, xy_lt)438 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)439 def test_matrix_matrix(self):440 xy_lt = core.LabeledTensor(441 array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])442 yz_lt = core.LabeledTensor(443 array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])444 matmul_lt = ops.matmul(xy_lt, yz_lt)445 golden_lt = core.LabeledTensor(446 math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])447 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)448 transpose = lambda x: 
core.transpose(x, list(x.axes.keys())[::-1])449 matmul_lt = ops.matmul(xy_lt, transpose(yz_lt))450 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)451 matmul_lt = ops.matmul(transpose(xy_lt), yz_lt)452 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)453 matmul_lt = ops.matmul(transpose(xy_lt), transpose(yz_lt))454 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)455 matmul_lt = ops.matmul(yz_lt, xy_lt)456 self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt))457 def test_matrix_matrix_axis_order(self):458 xy_lt = core.LabeledTensor(459 array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])460 yz_lt = core.LabeledTensor(461 array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])462 golden_lt = core.LabeledTensor(463 math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])464 with core.axis_order_scope(['x', 'y', 'z']):465 matmul_lt = ops.matmul(xy_lt, yz_lt)466 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)467 matmul_lt = ops.matmul(yz_lt, xy_lt)468 self.assertLabeledTensorsEqual(matmul_lt, golden_lt)469 def test_invalid(self):470 scalar_lt = core.LabeledTensor(array_ops.ones(()), [])471 x_lt = core.LabeledTensor(array_ops.ones((2,)), ['x'])472 x2_lt = core.LabeledTensor(array_ops.ones((3,)), ['x'])473 y_lt = core.LabeledTensor(array_ops.ones((3,)), ['y'])474 xy_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'y'])475 xyz_lt = core.LabeledTensor(array_ops.ones((2, 3, 1)), ['x', 'y', 'z'])476 with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'):477 ops.matmul(x_lt, scalar_lt)478 with self.assertRaises(NotImplementedError):479 ops.matmul(x_lt, xyz_lt)480 with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'):481 ops.matmul(x_lt, y_lt)482 with self.assertRaises(NotImplementedError):483 ops.matmul(xy_lt, xy_lt)484 with self.assertRaisesRegexp(ValueError, 'does not match'):485 ops.matmul(x_lt, x2_lt)486class ReduceSumTest(Base):487 def test_name(self):488 sum_lt = ops.reduce_sum(self.original_lt, 
{'channel'})489 self.assertIn('lt_reduce_sum', sum_lt.name)490 def test_drop_axis(self):491 sum_lt = ops.reduce_sum(self.original_lt, {'channel'})492 golden_lt = core.LabeledTensor(493 math_ops.reduce_sum(self.original_lt.tensor, 1),494 [self.a0, self.a2, self.a3])495 self.assertLabeledTensorsEqual(sum_lt, golden_lt)496 def test_drop_scalar_axis(self):497 sum_lt = ops.reduce_sum(self.original_lt, 'channel')498 golden_lt = core.LabeledTensor(499 math_ops.reduce_sum(self.original_lt.tensor, 1),500 [self.a0, self.a2, self.a3])501 self.assertLabeledTensorsEqual(sum_lt, golden_lt)502 def test_keep_axis(self):503 sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')})504 golden_lt = core.LabeledTensor(505 math_ops.reduce_sum(506 self.original_lt.tensor, 1, keep_dims=True),507 [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])508 self.assertLabeledTensorsEqual(sum_lt, golden_lt)509 def test_keep_scalar_axis(self):510 sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou'))511 golden_lt = core.LabeledTensor(512 math_ops.reduce_sum(513 self.original_lt.tensor, 1, keep_dims=True),514 [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])515 self.assertLabeledTensorsEqual(sum_lt, golden_lt)516 def test_scalar(self):517 scalar_lt = core.LabeledTensor(constant_op.constant(42), [])518 reduce_lt = ops.reduce_sum(scalar_lt, [])519 self.assertLabeledTensorsEqual(reduce_lt, scalar_lt)520 def test_empty_list(self):521 reduce_lt = ops.reduce_sum(self.original_lt, [])522 self.assertLabeledTensorsEqual(reduce_lt, self.original_lt)523 def test_none(self):524 sum_lt = ops.reduce_sum(self.original_lt)525 golden_lt = core.LabeledTensor(526 math_ops.reduce_sum(self.original_lt.tensor), [])527 self.assertLabeledTensorsEqual(sum_lt, golden_lt)528 def test_function_docstring_and_name(self):529 self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__)530 self.assertEqual('reduce_sum', ops.reduce_sum.__name__)531class ReduceMeanTest(Base):532 def 
test_name(self):533 actual_lt = ops.reduce_mean(self.original_lt, {'channel'})534 self.assertIn('lt_reduce_mean', actual_lt.name)535 def test(self):536 actual_lt = ops.reduce_mean(self.original_lt, {'channel'})537 golden_lt = core.LabeledTensor(538 math_ops.reduce_mean(self.original_lt.tensor, 1),539 [self.a0, self.a2, self.a3])540 self.assertLabeledTensorsEqual(actual_lt, golden_lt)541class ReduceProdTest(Base):542 def test_name(self):543 result_lt = ops.reduce_prod(self.original_lt, {'channel'})544 self.assertIn('lt_reduce_prod', result_lt.name)545 def test(self):546 result_lt = ops.reduce_prod(self.original_lt, {'channel'})547 golden_lt = core.LabeledTensor(548 math_ops.reduce_prod(self.original_lt.tensor, 1),549 [self.a0, self.a2, self.a3])550 self.assertLabeledTensorsEqual(result_lt, golden_lt)551class ReduceMinTest(Base):552 def test_name(self):553 result_lt = ops.reduce_min(self.original_lt, {'channel'})554 self.assertIn('lt_reduce_min', result_lt.name)555 def test(self):556 result_lt = ops.reduce_min(self.original_lt, {'channel'})557 golden_lt = core.LabeledTensor(558 math_ops.reduce_min(self.original_lt.tensor, 1),559 [self.a0, self.a2, self.a3])560 self.assertLabeledTensorsEqual(result_lt, golden_lt)561class ReduceMaxTest(Base):562 def test_name(self):563 result_lt = ops.reduce_max(self.original_lt, {'channel'})564 self.assertIn('lt_reduce_max', result_lt.name)565 def test(self):566 result_lt = ops.reduce_max(self.original_lt, {'channel'})567 golden_lt = core.LabeledTensor(568 math_ops.reduce_max(self.original_lt.tensor, 1),569 [self.a0, self.a2, self.a3])570 self.assertLabeledTensorsEqual(result_lt, golden_lt)571class BaseReduceBoolean(Base):572 def setUp(self):573 super(BaseReduceBoolean, self).setUp()574 self.bool_tensor = math_ops.cast(self.original_lt.tensor > 5, dtypes.bool)575 self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes)576class ReduceAllTest(BaseReduceBoolean):577 def test_name(self):578 result_lt = 
ops.reduce_all(self.bool_lt, {'channel'})579 self.assertIn('lt_reduce_all', result_lt.name)580 def test(self):581 result_lt = ops.reduce_all(self.bool_lt, {'channel'})582 golden_lt = core.LabeledTensor(583 math_ops.reduce_all(self.bool_tensor, 1), [self.a0, self.a2, self.a3])584 self.assertLabeledTensorsEqual(result_lt, golden_lt)585class ReduceAnyTest(BaseReduceBoolean):586 def test_name(self):587 result_lt = ops.reduce_any(self.bool_lt, {'channel'})588 self.assertIn('lt_reduce_any', result_lt.name)589 def test(self):590 result_lt = ops.reduce_any(self.bool_lt, {'channel'})591 golden_lt = core.LabeledTensor(592 math_ops.reduce_any(self.bool_tensor, 1), [self.a0, self.a2, self.a3])593 self.assertLabeledTensorsEqual(result_lt, golden_lt)594class TileTest(Base):595 def test_name(self):596 tile_lt = ops.tile(self.original_lt, {'z': 2})597 self.assertIn('lt_tile', tile_lt.name)598 def test(self):599 for multiple in [2, constant_op.constant(2)]:600 tile_lt = ops.tile(self.original_lt, {'z': multiple})601 golden_op = array_ops.tile(self.original_lt.tensor, [1, 1, multiple, 1])602 golden_axes = [603 'z' if axis.name == 'z' else axis604 for axis in self.original_lt.axes.values()605 ]606 golden_lt = core.LabeledTensor(golden_op, golden_axes)607 self.assertLabeledTensorsEqual(tile_lt, golden_lt)608 def test_invalid_input(self):609 with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):610 ops.tile(self.original_lt, {'foo': 5})611 with self.assertRaisesRegexp(ValueError, 'axes with tick labels'):612 ops.tile(self.original_lt, {'x': 5})613class PadTest(Base):614 def test_name(self):615 pad_lt = ops.pad(self.original_lt,616 {'x': (1, 1),617 'channel': ([], ['alpha'])})618 self.assertIn('lt_pad', pad_lt.name)619 def test(self):620 pad_lt = ops.pad(self.original_lt,621 {'x': (1, 1),622 'channel': ([], ['alpha'])})623 golden_op = array_ops.pad(self.original_lt.tensor, [[1, 1], [0, 1], [0, 0],624 [0, 0]])625 golden_axes = [('x', self.x_size + 2),626 ('channel', 
['red', 'green', 'blue', 'alpha']), self.a2,627 self.a3]628 golden_lt = core.LabeledTensor(golden_op, golden_axes)629 self.assertLabeledTensorsEqual(pad_lt, golden_lt)630 def test_invalid_input(self):631 with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):632 ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])})633class ConstantTest(Base):634 def test_name(self):635 constant_lt = ops.constant(1)636 self.assertIn('lt_constant', constant_lt.name)637 def test_scalar(self):638 constant_lt = ops.constant(1)639 golden_lt = core.LabeledTensor(constant_op.constant(1), [])640 self.assertLabeledTensorsEqual(constant_lt, golden_lt)641 def test_infer_shape(self):642 constant_lt = ops.constant([1, 2], axes=['x'])643 golden_lt = core.LabeledTensor(constant_op.constant([1, 2]), ['x'])644 self.assertLabeledTensorsEqual(constant_lt, golden_lt)645 def test_specify_shape(self):646 constant_lt = ops.constant(1, axes=[('x', 3)])647 golden_lt = core.LabeledTensor(constant_op.constant(1, shape=(3,)), ['x'])648 self.assertLabeledTensorsEqual(constant_lt, golden_lt)649 def test_existing_axes(self):650 golden_lt = core.LabeledTensor(constant_op.constant([1, 2]), ['x'])651 constant_lt = ops.constant([1, 2], axes=golden_lt.axes)652 self.assertLabeledTensorsEqual(constant_lt, golden_lt)653class ZerosLikeTest(Base):654 def test_name(self):655 like_lt = ops.zeros_like(self.original_lt)656 self.assertIn('lt_zeros_like', like_lt.name)657 def test(self):658 like_lt = ops.zeros_like(self.original_lt)659 golden_lt = core.LabeledTensor(660 array_ops.zeros_like(self.original_lt.tensor), self.original_lt.axes)661 self.assertLabeledTensorsEqual(like_lt, golden_lt)662class OnesLikeTest(Base):663 def test_name(self):664 like_lt = ops.ones_like(self.original_lt)665 self.assertIn('lt_ones_like', like_lt.name)666 def test(self):667 like_lt = ops.ones_like(self.original_lt)668 golden_lt = core.LabeledTensor(669 array_ops.ones_like(self.original_lt.tensor), 
self.original_lt.axes)670 self.assertLabeledTensorsEqual(like_lt, golden_lt)671class CastTest(Base):672 def test_name(self):673 cast_lt = ops.cast(self.original_lt, dtypes.float16)674 self.assertIn('lt_cast', cast_lt.name)675 def test(self):676 cast_lt = ops.cast(self.original_lt, dtypes.float16)677 golden_lt = core.LabeledTensor(678 math_ops.cast(self.original_lt.tensor, dtypes.float16),679 self.original_lt.axes)680 self.assertLabeledTensorsEqual(cast_lt, golden_lt)681class VerifyTensorAllFiniteTest(Base):682 def setUp(self):683 super(VerifyTensorAllFiniteTest, self).setUp()684 self.finite_lt = core.LabeledTensor(constant_op.constant(42.0), [])685 self.nan_lt = core.LabeledTensor(constant_op.constant(np.nan), [])686 self.checked_finite_lt = ops.verify_tensor_all_finite(self.finite_lt, '')687 self.checked_nan_lt = ops.verify_tensor_all_finite(self.nan_lt, '')688 def test_name(self):689 self.assertIn('lt_verify_tensor_all_finite', self.checked_finite_lt.name)690 self.assertIn('lt_verify_tensor_all_finite', self.checked_nan_lt.name)691 def test_finite(self):692 self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt)693 def test_nan(self):694 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,695 'Tensor had NaN values'):696 self.eval([self.checked_nan_lt])697class BooleanMaskTest(Base):698 def test_name(self):699 mask = core.LabeledTensor(math_ops.range(7) > 3, [self.a0])700 masked_lt = ops.boolean_mask(self.original_lt, mask)701 self.assertIn('lt_boolean_mask', masked_lt.name)702 def test(self):703 mask = core.LabeledTensor(math_ops.range(7) > 3, [self.a0])704 masked_lt = ops.boolean_mask(self.original_lt, mask)705 golden_lt = core.LabeledTensor(706 array_ops.boolean_mask(self.original_lt.tensor, mask.tensor),707 ['x', self.a1, self.a2, self.a3])708 self.assertLabeledTensorsEqual(masked_lt, golden_lt)709 def test_invalid_rank(self):710 mask = core.LabeledTensor(array_ops.ones((7, 3)) > 3, [self.a0, self.a1])711 with 
self.assertRaises(NotImplementedError):712 ops.boolean_mask(self.original_lt, mask)713 def test_mismatched_axis(self):714 mask = core.LabeledTensor(math_ops.range(7) > 3, ['foo'])715 with self.assertRaisesRegexp(ValueError, 'not equal'):716 ops.boolean_mask(self.original_lt, mask)717class WhereTest(Base):718 def test_name(self):719 condition = core.LabeledTensor(math_ops.range(5) < 3, ['x'])720 where_lt = ops.where(condition, condition, condition)721 self.assertIn('lt_where', where_lt.name)722 def test(self):723 condition = core.LabeledTensor(math_ops.range(5) < 3, ['x'])724 x = core.LabeledTensor(array_ops.ones(5), ['x'])725 y = core.LabeledTensor(array_ops.zeros(5), ['x'])726 where_lt = ops.where(condition, x, y)727 golden_lt = core.LabeledTensor(728 array_ops.concat([array_ops.ones(3), array_ops.zeros(2)], 0), ['x'])729 self.assertLabeledTensorsEqual(where_lt, golden_lt)730 def test_mismatched_axes(self):731 condition = core.LabeledTensor(math_ops.range(5) < 3, ['x'])732 with self.assertRaisesRegexp(ValueError, 'equal axes'):733 ops.where(condition, condition[:3], condition)734 with self.assertRaisesRegexp(ValueError, 'equal axes'):735 ops.where(condition, condition, condition[:3])736if __name__ == '__main__':...
core_test.py
Source:core_test.py
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.2#3# Licensed under the Apache License, Version 2.0 (the "License");4# you may not use this file except in compliance with the License.5# You may obtain a copy of the License at6#7# http://www.apache.org/licenses/LICENSE-2.08#9# Unless required by applicable law or agreed to in writing, software10# distributed under the License is distributed on an "AS IS" BASIS,11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15from __future__ import absolute_import16from __future__ import division17from __future__ import print_function18import operator19import re20import textwrap21import numpy as np22from six.moves import range # pylint: disable=redefined-builtin23from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc24from tensorflow.contrib.labeled_tensor.python.ops import core25from tensorflow.contrib.labeled_tensor.python.ops import test_util26from tensorflow.python.framework import constant_op27from tensorflow.python.framework import dtypes28from tensorflow.python.framework import ops29from tensorflow.python.framework import tensor_shape30from tensorflow.python.ops import array_ops31from tensorflow.python.ops import math_ops32from tensorflow.python.platform import test as test_lib33class AxisTest(test_lib.TestCase):34 def setUp(self):35 d_7 = tensor_shape.Dimension(7)36 p_rgb = ['red', 'green', 'blue']37 self.i_7 = core.Axis('7', d_7)38 self.i_7p = core.Axis('7prime', d_7)39 self.i_rgb = core.Axis('rgb', p_rgb)40 self.i_range = core.Axis('range', range(7))41 self.i_unknown = core.Axis('unknown', None)42 def test_equality(self):43 axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]44 for i, axis_0 in enumerate(axes):45 for j, axis_1 in enumerate(axes):46 if i == j:47 
self.assertEqual(axis_0, axis_1)48 else:49 self.assertNotEqual(axis_0, axis_1)50 def test_axis_value(self):51 self.assertEqual(self.i_7.value, tensor_shape.Dimension(7))52 self.assertTrue(self.i_range.value == tuple(range(7)))53 def test_axis_input(self):54 axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]55 for axis in axes:56 self.assertEqual(axis, core.Axis(axis.name, axis.value))57 def test_axis_value_input(self):58 axis = self.i_range59 for value in [range(7), list(range(7)), np.arange(7)]:60 self.assertEqual(axis, core.Axis(axis.name, value))61 def test_size(self):62 self.assertEqual(len(self.i_7), 7)63 self.assertEqual(len(self.i_rgb), 3)64 self.assertEqual(len(self.i_range), 7)65 self.assertEqual(self.i_unknown.size, None)66 def test_concat_single(self):67 red = core.Axis('rgb', ['red'])68 self.assertEqual(core.concat_axes([red]), red)69 def test_concat_many(self):70 red = core.Axis('rgb', ['red'])71 green = core.Axis('rgb', ['green'])72 blue = core.Axis('rgb', ['blue'])73 red_green_blue = core.Axis('rgb', ['red', 'green', 'blue'])74 self.assertEqual(core.concat_axes([red, green, blue]), red_green_blue)75 def test_concat_different_names(self):76 red = core.Axis('red', ['red'])77 green = core.Axis('green', ['red'])78 with self.assertRaises(ValueError):79 core.concat_axes([red, green])80 def test_concat_unknown(self):81 red = core.Axis('rgb', None)82 green = core.Axis('rgb', None)83 self.assertEqual(core.concat_axes([red, green]), red)84 def test_repr(self):85 self.assertEqual("Axis('7', Dimension(7))", repr(self.i_7))86 def test_invalid_input(self):87 with self.assertRaises(TypeError):88 core.Axis('foo', [{}])89 with self.assertRaises(ValueError):90 core.Axis('foo', [1, 2, 3, 1])91 red = core.Axis('foo', ['red'])92 with self.assertRaises(tc.Error):93 core.concat_axes([red, 1])94 def test_as_axis(self):95 self.assertEqual(self.i_7, core.as_axis(('7', 7)))96 self.assertEqual(self.i_7, core.as_axis(self.i_7))97class 
AxesTest(test_lib.TestCase):98 def setUp(self):99 d_7 = tensor_shape.Dimension(7)100 d_8 = tensor_shape.Dimension(8)101 p_rgb = ['red', 'green', 'blue']102 p_range = range(7)103 self.i_8 = core.Axis('8', d_8)104 self.a0 = core.Axes([('d7', d_7)])105 self.a1 = core.Axes([('d7', d_7)])106 self.a2 = core.Axes([('d7', d_7), ('rgb', p_rgb)])107 self.a3 = core.Axes([('8', d_8), ('range', p_range)])108 def test_equality(self):109 self.assertEqual(self.a0, self.a0)110 self.assertEqual(self.a0, self.a1)111 self.assertNotEqual(self.a0, self.a2)112 def test_repr(self):113 self.assertEqual("Axes([('d7', Dimension(7))])", repr(self.a0))114 def test_remove(self):115 a = self.a3.remove('range')116 self.assertEqual(a, core.Axes([self.i_8]))117 with self.assertRaises(KeyError):118 self.a3.remove('foobar')119 def test_typecheck_error_message(self):120 pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., '121 'Union(Union(numpy.ndarray, %s, list, tuple), '122 'Optional(Union(tensorflow.Dimension, int))))))' %123 range.__name__)124 regexp = re.escape(pattern).replace(re.escape('...'), '.*')125 with self.assertRaisesRegexp(tc.Error, 'allowed type ' + regexp):126 core.Axes(None)127class LabeledTensorTest(test_util.Base):128 def setUp(self):129 tensor = array_ops.ones([7, 3, 8, 1])130 a0 = ('x', range(7))131 a1 = ('channel', ['red', 'green', 'blue'])132 a2 = ('y', 8)133 a3 = ('z', tensor_shape.Dimension(1))134 self.lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])135 def test_repr(self):136 pattern = textwrap.dedent("""\137 <LabeledTensor '...' 
shape=(7, 3, 8, 1) dtype=float32138 axes=[('x', ...),139 ('channel', ...),140 ('y', Dimension(8)),141 ('z', Dimension(1))]>""")142 regexp = re.escape(pattern).replace(re.escape('...'), '.*')143 self.assertRegexpMatches(repr(self.lt), regexp)144 def test_reuse_existing_axes(self):145 alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes)146 self.assertLabeledTensorsEqual(alt_lt, self.lt)147 def test_reuse_existing_axis_objects(self):148 alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes.values())149 self.assertLabeledTensorsEqual(alt_lt, self.lt)150 def test_indexing_scalars(self):151 actual = self.lt[:, :, :, 0]152 expected = core.LabeledTensor(self.lt.tensor[:, :, :, 0],153 list(self.lt.axes.values())[:-1])154 self.assertLabeledTensorsEqual(actual, expected)155 actual = self.lt[1, :, :, 0]156 expected = core.LabeledTensor(self.lt.tensor[1, :, :, 0],157 list(self.lt.axes.values())[1:-1])158 self.assertLabeledTensorsEqual(actual, expected)159 actual = self.lt[1, 2, :, 0]160 expected = core.LabeledTensor(self.lt.tensor[1, 2, :, 0],161 list(self.lt.axes.values())[2:-1])162 self.assertLabeledTensorsEqual(actual, expected)163 def test_indexing_1d(self):164 lt_1d = self.lt[1, 2, :, 0]165 actual = lt_1d[3]166 expected = core.LabeledTensor(lt_1d.tensor[3], [])167 self.assertLabeledTensorsEqual(actual, expected)168 def test_indexing_slices(self):169 actual = self.lt[:3, :, :, :]170 axes = [('x', range(3))] + list(self.lt.axes.values())[1:]171 expected = core.LabeledTensor(self.lt.tensor[:3, :, :, :], axes)172 self.assertLabeledTensorsEqual(actual, expected)173 def test_invalid_indexing(self):174 with self.assertRaises(ValueError):175 self.lt[0] # pylint: disable=pointless-statement176 with self.assertRaises(ValueError):177 self.lt[:, :, :, :, 0] # pylint: disable=pointless-statement178 def test_unknown_size(self):179 tensor = array_ops.placeholder(dtypes.string, [None])180 actual = core.LabeledTensor(tensor, ['x'])181 self.assertIsNone(actual.axes['x'].size)182 
self.assertIsNone(actual.axes['x'].value.value)183 def test_eq(self):184 self.assertEqual(self.lt, self.lt)185 self.assertNotEqual(self.lt, self.lt.tensor)186 self.assertNotEqual(self.lt.tensor, self.lt)187 def test_hash(self):188 lt1 = self.lt189 lt2 = core.LabeledTensor(self.lt.tensor, self.lt.axes)190 self.assertEqual(lt1, lt2)191 self.assertEqual(hash(lt1), hash(lt2))192 def test_name(self):193 self.assertEqual(self.lt.name, self.lt.tensor.name)194 def test_dtype(self):195 self.assertEqual(self.lt.dtype, self.lt.tensor.dtype)196 def test_shape(self):197 self.assertEqual(self.lt.shape, self.lt.tensor.shape)198 def test_get_shape(self):199 self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape())200 def test_convert_to_tensor(self):201 expected = self.lt.tensor202 actual = ops.convert_to_tensor(self.lt)203 self.assertIs(expected, actual)204class Base(test_util.Base):205 def setUp(self):206 self.x_size = 7207 self.channel_size = 3208 self.z_size = 4209 self.probs_size = 11210 tensor = math_ops.range(0, self.x_size * self.channel_size * self.z_size *211 self.probs_size)212 tensor = array_ops.reshape(213 tensor, [self.x_size, self.channel_size, self.z_size, self.probs_size])214 a0 = ('x', range(self.x_size))215 a1 = ('channel', ['red', 'green', 'blue'])216 a2 = 'z'217 a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))218 self.tensor = tensor219 self.a0 = a0220 self.a1 = a1221 self.a2 = a2222 self.a3 = a3223 self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])224 self.x_probs_lt = core.slice_function(self.original_lt,225 {'z': 0,226 'channel': 0})227 self.channel_probs_lt = core.slice_function(self.original_lt,228 {'x': 3,229 'z': 0})230class IdentityTest(Base):231 def test_name(self):232 identity_lt = core.identity(self.original_lt)233 self.assertIn('lt_identity', identity_lt.name)234class SliceFunctionTest(Base):235 def test_name(self):236 select_lt = core.slice_function(self.original_lt, {'channel': 1})237 self.assertIn('lt_slice', 
select_lt.name)238 def test_scalar(self):239 select_lt = core.slice_function(self.original_lt, {'channel': 1})240 golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :],241 [self.a0, self.a2, self.a3])242 self.assertLabeledTensorsEqual(select_lt, golden_lt)243 def test_slice(self):244 select_lt = core.slice_function(self.original_lt, {'channel': slice(0, 2)})245 a1_sliced = ('channel', ['red', 'green'])246 golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],247 [self.a0, a1_sliced, self.a2, self.a3])248 self.assertLabeledTensorsEqual(select_lt, golden_lt)249 def test_slices(self):250 select_lt = core.slice_function(251 self.original_lt, {'x': slice(1, 5),252 'channel': slice(1, None)})253 a0_sliced = ('x', range(1, 5))254 a1_sliced = ('channel', ['green', 'blue'])255 golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],256 [a0_sliced, a1_sliced, self.a2, self.a3])257 self.assertLabeledTensorsEqual(select_lt, golden_lt)258 def test_slice_unlabeled(self):259 select_lt = core.slice_function(self.original_lt, {'z': slice(1, 3)})260 a2_sliced = 'z'261 golden_lt = core.LabeledTensor(self.tensor[:, :, 1:3, :],262 [self.a0, self.a1, a2_sliced, self.a3])263 self.assertLabeledTensorsEqual(select_lt, golden_lt)264 def test_slice_unknown_shape(self):265 lt = core.LabeledTensor(266 array_ops.placeholder(dtypes.float32, [None, 1]), ['x', 'y'])267 sliced_lt = core.slice_function(lt, {'y': 0})268 self.assertEqual(list(sliced_lt.axes.values()), [lt.axes['x']])269class TransposeTest(Base):270 def test_name(self):271 transpose_lt = core.transpose(self.original_lt,272 self.original_lt.axes.keys())273 self.assertIn('lt_transpose', transpose_lt.name)274 def test_identity(self):275 transpose_lt = core.transpose(self.original_lt,276 self.original_lt.axes.keys())277 golden_lt = self.original_lt278 self.assertLabeledTensorsEqual(transpose_lt, golden_lt)279 def test(self):280 transpose_lt = core.transpose(self.original_lt,281 ['z', 'channel', 'x', 'probs'])282 golden_lt = 
core.LabeledTensor(283 array_ops.transpose(self.tensor, [2, 1, 0, 3]),284 [self.a2, self.a1, self.a0, self.a3])285 self.assertLabeledTensorsEqual(transpose_lt, golden_lt)286 def test_default_axis_order(self):287 transpose_lt = core.transpose(self.original_lt)288 golden_lt = core.LabeledTensor(289 array_ops.transpose(self.tensor, [3, 2, 1, 0]),290 list(reversed(list(self.original_lt.axes.values()))))291 self.assertLabeledTensorsEqual(transpose_lt, golden_lt)292 def test_invalid_input(self):293 with self.assertRaises(ValueError):294 core.transpose(self.original_lt, ['channel', 'x', 'probs'])295 with self.assertRaises(ValueError):296 core.transpose(self.original_lt, ['z', 'foo', 'x', 'probs'])297class ExpandDimsTest(Base):298 def test_name(self):299 expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())300 self.assertIn('lt_expand', expand_lt.name)301 def test_identity(self):302 expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())303 golden_lt = self.original_lt304 self.assertLabeledTensorsEqual(expand_lt, golden_lt)305 def test(self):306 expand_lt = core.expand_dims(307 self.original_lt, ['foo', 'x', 'bar', 'channel', 'z', 'probs', 'grok'])308 golden_lt = core.LabeledTensor(309 array_ops.reshape(self.tensor, [310 1, self.x_size, 1, self.channel_size, self.z_size, self.probs_size,311 1312 ]), ['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok'])313 self.assertLabeledTensorsEqual(expand_lt, golden_lt)314 def test_label(self):315 expand_lt = core.expand_dims(self.original_lt, [316 'x',317 'channel',318 ('foo', 'bar'),319 'z',320 'probs',321 ])322 golden_lt = core.LabeledTensor(323 array_ops.reshape(324 self.tensor,325 [self.x_size, self.channel_size, 1, self.z_size, self.probs_size]),326 [self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3])327 self.assertLabeledTensorsEqual(expand_lt, golden_lt)328 def test_unknown_dimension(self):329 orig_lt = core.LabeledTensor(330 array_ops.placeholder(dtypes.float32, [None]), 
['x'])331 expand_lt = core.expand_dims(orig_lt, ['x', 'y'])332 self.assertEqual(expand_lt.axes, core.Axes([('x', None), ('y', 1)]))333 def test_invalid_input(self):334 with self.assertRaises(core.AxisOrderError):335 core.expand_dims(self.original_lt,336 ['foo', 'not_x', 'bar', 'channel', 'z', 'probs', 'grok'])337 with self.assertRaises(core.AxisOrderError):338 core.expand_dims(self.original_lt,339 ['foo', 'z', 'bar', 'channel', 'x', 'probs', 'grok'])340class AxisOrderScopeTest(Base):341 def test(self):342 xyz = ['x', 'y', 'z']343 abc = ['a', 'b', 'c']344 self.assertIsNone(core.get_axis_order())345 with core.axis_order_scope(xyz):346 self.assertEqual(core.get_axis_order(), xyz)347 with core.axis_order_scope():348 self.assertIsNone(core.get_axis_order())349 with core.axis_order_scope(abc):350 self.assertEqual(core.get_axis_order(), abc)351 self.assertIsNone(core.get_axis_order())352 self.assertEqual(core.get_axis_order(), xyz)353 self.assertIsNone(core.get_axis_order())354class CheckAxisOrderTest(Base):355 def test_passes(self):356 axis_order = ['w', 'x', 'y', 'z']357 lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order)358 core.check_axis_order(lt, axis_order)359 lt = core.LabeledTensor(array_ops.ones((1, 1, 1)), axis_order[1:])360 core.check_axis_order(lt, axis_order)361 lt = core.LabeledTensor(array_ops.ones((1, 1, 1)), axis_order[:-1])362 core.check_axis_order(lt, axis_order)363 def test_invalid(self):364 axis_order = ['w', 'x', 'y', 'z']365 lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order)366 with self.assertRaises(core.AxisOrderError):367 core.check_axis_order(lt)368 with self.assertRaises(core.AxisOrderError):369 core.check_axis_order(lt, axis_order[:-1])370 with self.assertRaises(core.AxisOrderError):371 core.check_axis_order(lt, axis_order[::-1])372 def test_scope(self):373 axis_order = ['w', 'x', 'y', 'z']374 lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order)375 with core.axis_order_scope(axis_order):376 
core.check_axis_order(lt)377class ImposeAxisOrderTest(Base):378 def test_identity(self):379 axis_order = ['w', 'x', 'y', 'z']380 lt = core.LabeledTensor(381 array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order)382 actual = core.impose_axis_order(lt, axis_order)383 self.assertLabeledTensorsEqual(lt, actual)384 lt = core.LabeledTensor(385 array_ops.reshape(math_ops.range(6), (1, 2, 3)), axis_order[:3])386 actual = core.impose_axis_order(lt, axis_order)387 self.assertLabeledTensorsEqual(lt, actual)388 def test_reverse(self):389 axis_order = ['w', 'x', 'y', 'z']390 lt = core.LabeledTensor(391 array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order)392 actual = core.impose_axis_order(lt, axis_order[::-1])393 expected = core.transpose(lt, axis_order[::-1])394 self.assertLabeledTensorsEqual(expected, actual)395 lt = core.LabeledTensor(396 array_ops.reshape(math_ops.range(6), (1, 2, 3)), axis_order[:3])397 actual = core.impose_axis_order(lt, axis_order[::-1])398 expected = core.transpose(lt, ['y', 'x', 'w'])399 self.assertLabeledTensorsEqual(expected, actual)400 def test_scope(self):401 axis_order = ['w', 'x', 'y', 'z']402 lt = core.LabeledTensor(403 array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order)404 expected = core.transpose(lt, axis_order[::-1])405 with core.axis_order_scope(axis_order[::-1]):406 actual = core.impose_axis_order(lt)407 self.assertLabeledTensorsEqual(expected, actual)408 def test_invalid(self):409 lt = core.LabeledTensor(410 array_ops.reshape(math_ops.range(2), (1, 2)), ['x', 'y'])411 with self.assertRaises(ValueError):412 core.impose_axis_order(lt)413 with self.assertRaises(ValueError):414 core.impose_axis_order(lt, ['x'])415class FindConsistentOrderingTest(Base):416 def test(self):417 cases = [418 ([], [], []),419 (['x'], [], ['x']),420 ([], ['x'], ['x']),421 (['x'], ['x'], ['x']),422 (['x'], ['y'], ['x', 'y']),423 (['y'], ['x'], ['y', 'x']),424 (['x', 'y'], ['x', 'y'], ['x', 'y']),425 (['x', 'y'], ['y', 'x'], 
None),426 (['x', 'y'], ['y', 'z'], ['x', 'y', 'z']),427 (['x', 'z'], ['y', 'z'], ['x', 'y', 'z']),428 (['x', 'y'], ['x', 'z'], ['x', 'y', 'z']),429 (['w', 'x'], ['y', 'z'], ['w', 'x', 'y', 'z']),430 (['x', 'y', 'z'], ['z', 'x'], None),431 (['x', 'y', 'z'], ['x'], ['x', 'y', 'z']),432 ([], ['x', 'y', 'z'], ['x', 'y', 'z']),433 ]434 for a, b, expected in cases:435 actual = core._find_consistent_ordering(a, b)436 msg = ('unexpected ordering between %r and %r:\nexpected: %r\nactual: %r'437 % (a, b, expected, actual))438 self.assertEqual(expected, actual, msg=msg)439class AlignTest(Base):440 def test_name(self):441 align_lt_0, align_lt_1, _ = core.align(self.original_lt, self.original_lt)442 self.assertIn('lt_align', align_lt_0.name)443 self.assertIn('/0', align_lt_0.name)444 self.assertIn('lt_align', align_lt_1.name)445 self.assertIn('/1', align_lt_1.name)446 def test_identical_shaped_inputs(self):447 offset_tensor = self.original_lt.tensor + 1448 offset_lt = core.LabeledTensor(offset_tensor, self.original_lt.axes)449 align_lt, align_offset_lt, broadcast_axes = core.align(self.original_lt,450 offset_lt)451 self.assertLabeledTensorsEqual(align_lt, self.original_lt)452 self.assertLabeledTensorsEqual(align_offset_lt, offset_lt)453 self.assertEqual(broadcast_axes, self.original_lt.axes)454 def test_different_inputs(self):455 # The correct axis ordering is ['x', 'channel', 'probs'].456 align_x_probs_lt, align_channel_probs_lt, broadcast_axes = core.align(457 self.x_probs_lt, self.channel_probs_lt)458 x_probs_golden_lt = core.LabeledTensor(459 array_ops.reshape(self.x_probs_lt.tensor,460 [self.x_size, 1, self.probs_size]),461 [self.a0, 'channel', self.a3])462 self.assertLabeledTensorsEqual(align_x_probs_lt, x_probs_golden_lt)463 channel_probs_golden_lt = core.LabeledTensor(464 array_ops.reshape(self.channel_probs_lt.tensor,465 [1, self.channel_size, self.probs_size]),466 ['x', self.a1, self.a3])467 self.assertLabeledTensorsEqual(align_channel_probs_lt,468 
channel_probs_golden_lt)469 self.assertEqual(broadcast_axes, core.Axes([self.a0, self.a1, self.a3]))470 def test_axis_order_scope(self):471 xz_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'z'])472 yz_lt = core.LabeledTensor(array_ops.ones((4, 3)), ['y', 'z'])473 _, _, broadcast_axes = core.align(xz_lt, yz_lt)474 self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])475 _, _, broadcast_axes = core.align(yz_lt, xz_lt)476 self.assertEqual(list(broadcast_axes.keys()), ['y', 'x', 'z'])477 with core.axis_order_scope(['x', 'y', 'z']):478 _, _, broadcast_axes = core.align(yz_lt, xz_lt)479 self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])480 with core.axis_order_scope(['x', 'y']):481 with self.assertRaises(core.AxisOrderError):482 core.align(xz_lt, yz_lt)483 with self.assertRaises(core.AxisOrderError):484 core.align(yz_lt, xz_lt)485 def test_invalid_input(self):486 lt_0 = core.LabeledTensor(array_ops.zeros([5]), [('a', range(5))])487 lt_1 = core.LabeledTensor(array_ops.zeros([5]), [('a', range(1, 6))])488 with self.assertRaises(ValueError):489 core.align(lt_0, lt_1)490class ConvertToLabeledTensorTest(Base):491 # TODO(shoyer): Simplify these tests once we can reuse labeled tensors in492 # assertLabeledTensorsEqual.493 def test_labeled_tensor(self):494 actual = core.convert_to_labeled_tensor(self.original_lt)495 self.assertLabeledTensorsEqual(actual, self.original_lt)496 def test_python_scalar(self):497 actual = core.convert_to_labeled_tensor(42)498 golden_lt = core.LabeledTensor(ops.convert_to_tensor(42), [])499 self.assertLabeledTensorsEqual(actual, golden_lt)500 def test_numpy_array(self):501 actual = core.convert_to_labeled_tensor(np.array(42))502 golden_lt = core.LabeledTensor(ops.convert_to_tensor(42), [])503 self.assertLabeledTensorsEqual(actual, golden_lt)504 def test_tensor(self):505 actual = core.convert_to_labeled_tensor(constant_op.constant(42))506 golden_lt = core.LabeledTensor(ops.convert_to_tensor(42), [])507 
self.assertLabeledTensorsEqual(actual, golden_lt)508 def test_invalid_input(self):509 with self.assertRaises(ValueError):510 core.convert_to_labeled_tensor(math_ops.range(5))511 with self.assertRaises(ValueError):512 core.convert_to_labeled_tensor(np.array([1, 2]))513class DocStringCheckMixin(object):514 # requires self.ops to be defined515 def test_function_docstring_and_name(self):516 for op_name, _, _, lt_op in self.ops:517 if lt_op is not None:518 self.assertIn('tf.%s' % op_name, lt_op.__doc__)519 self.assertEqual(op_name, lt_op.__name__)520class UnaryOpsTestsMixin(object):521 # requires self.ops and self.test_lt to be defined522 def test_core_op(self):523 for op_name, _, tf_op, lt_op in self.ops:524 if tf_op is not None:525 golden_lt = core.LabeledTensor(526 tf_op(self.test_lt.tensor), self.test_lt.axes)527 actual_lt = lt_op(self.test_lt)528 self.assertIn(op_name, actual_lt.name)529 self.assertLabeledTensorsEqual(golden_lt, actual_lt)530 def test_infix(self):531 for op_name, infix_op, _, _ in self.ops:532 if infix_op is not None:533 expected_lt = core.LabeledTensor(534 infix_op(self.test_lt.tensor), self.test_lt.axes)535 actual_lt = infix_op(self.test_lt)536 self.assertIn(op_name, actual_lt.name)537 self.assertLabeledTensorsEqual(expected_lt, actual_lt)538class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):539 def setUp(self):540 super(CoreUnaryOpsTest, self).setUp()541 self.ops = [542 ('abs', operator.abs, math_ops.abs, core.abs_function),543 ('neg', operator.neg, math_ops.negative, core.neg),544 # TODO(shoyer): add unary + to core TensorFlow545 ('pos', None, None, None),546 ('sign', None, math_ops.sign, core.sign),547 ('reciprocal', None, math_ops.reciprocal, core.reciprocal),548 ('square', None, math_ops.square, core.square),549 ('round', None, math_ops.round, core.round_function),550 ('sqrt', None, math_ops.sqrt, core.sqrt),551 ('rsqrt', None, math_ops.rsqrt, core.rsqrt),552 ('log', None, math_ops.log, core.log),553 ('exp', None, 
math_ops.exp, core.exp),554 ('log', None, math_ops.log, core.log),555 ('ceil', None, math_ops.ceil, core.ceil),556 ('floor', None, math_ops.floor, core.floor),557 ('cos', None, math_ops.cos, core.cos),558 ('sin', None, math_ops.sin, core.sin),559 ('tan', None, math_ops.tan, core.tan),560 ('acos', None, math_ops.acos, core.acos),561 ('asin', None, math_ops.asin, core.asin),562 ('atan', None, math_ops.atan, core.atan),563 ('lgamma', None, math_ops.lgamma, core.lgamma),564 ('digamma', None, math_ops.digamma, core.digamma),565 ('erf', None, math_ops.erf, core.erf),566 ('erfc', None, math_ops.erfc, core.erfc),567 ('lgamma', None, math_ops.lgamma, core.lgamma),568 ]569 total_size = np.prod([v.size for v in self.original_lt.axes.values()])570 self.test_lt = core.LabeledTensor(571 math_ops.cast(self.original_lt, dtypes.float32) / total_size,572 self.original_lt.axes)573class LogicalNotTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):574 def setUp(self):575 super(LogicalNotTest, self).setUp()576 self.ops = [('logical_not', operator.invert, math_ops.logical_not,577 core.logical_not),]578 self.test_lt = self.original_lt < 10579class BinaryOpsTestsMixin(object):580 # requires self.ops, self.test_lt_1, self.test_lt_2, self.test_lt_1_broadcast581 # and self.test_lt_2_broadcast to be defined582 def test_core_op(self):583 for op_name, _, tf_op, lt_op in self.ops:584 golden_tensor = tf_op(self.test_lt_1_broadcast, self.test_lt_2_broadcast)585 golden_lt = core.LabeledTensor(golden_tensor, self.broadcast_axes)586 actual_lt = lt_op(self.test_lt_1, self.test_lt_2)587 self.assertIn(op_name, actual_lt.name)588 self.assertLabeledTensorsEqual(golden_lt, actual_lt)589 def test_infix(self):590 for op_name, infix_op, _, lt_op in self.ops:591 if infix_op is not None:592 expected_lt = lt_op(self.test_lt_1, self.test_lt_2)593 actual_lt = infix_op(self.test_lt_1, self.test_lt_2)594 self.assertIn(op_name, actual_lt.name)595 self.assertLabeledTensorsEqual(expected_lt, actual_lt)596class 
CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):597 def setUp(self):598 super(CoreBinaryOpsTest, self).setUp()599 self.x_probs_broadcast_tensor = array_ops.reshape(600 self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])601 self.channel_probs_broadcast_tensor = array_ops.reshape(602 self.channel_probs_lt.tensor, [1, self.channel_size, self.probs_size])603 # == and != are not element-wise for tf.Tensor, so they shouldn't be604 # elementwise for LabeledTensor, either.605 self.ops = [606 ('add', operator.add, math_ops.add, core.add),607 ('sub', operator.sub, math_ops.subtract, core.sub),608 ('mul', operator.mul, math_ops.multiply, core.mul),609 ('div', operator.truediv, math_ops.div, core.div),610 ('mod', operator.mod, math_ops.mod, core.mod),611 ('pow', operator.pow, math_ops.pow, core.pow_function),612 ('equal', None, math_ops.equal, core.equal),613 ('less', operator.lt, math_ops.less, core.less),614 ('less_equal', operator.le, math_ops.less_equal, core.less_equal),615 ('not_equal', None, math_ops.not_equal, core.not_equal),616 ('greater', operator.gt, math_ops.greater, core.greater),617 ('greater_equal', operator.ge, math_ops.greater_equal,618 core.greater_equal),619 ]620 self.test_lt_1 = self.x_probs_lt621 self.test_lt_2 = self.channel_probs_lt622 self.test_lt_1_broadcast = self.x_probs_broadcast_tensor623 self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor624 self.broadcast_axes = [self.a0, self.a1, self.a3]625 def test_reflexive(self):626 labeled_tensor = self.x_probs_lt + 1 # all elements must be >0 for division627 for op_name, infix_op, _, lt_op in self.ops:628 if infix_op is not None:629 expected_lt = lt_op(2, labeled_tensor)630 actual_lt = infix_op(2, labeled_tensor)631 # Python uses greater for the reflexive version of less (and vise-versa)632 if 'less' in op_name:633 op_name = op_name.replace('less', 'greater')634 elif 'greater' in op_name:635 op_name = op_name.replace('greater', 'less')636 self.assertIn(op_name, 
actual_lt.name)637 self.assertLabeledTensorsEqual(expected_lt, actual_lt)638class LogicalBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):639 def setUp(self):640 super(LogicalBinaryOpsTest, self).setUp()641 self.ops = [642 ('logical_and', operator.and_, math_ops.logical_and, core.logical_and),643 ('logical_or', operator.or_, math_ops.logical_or, core.logical_or),644 ('logical_xor', operator.xor, math_ops.logical_xor, core.logical_xor),645 ]646 self.test_lt_1 = self.original_lt < 10647 self.test_lt_2 = self.original_lt < 5648 self.test_lt_1_broadcast = self.test_lt_1.tensor649 self.test_lt_2_broadcast = self.test_lt_2.tensor650 self.broadcast_axes = self.test_lt_1.axes651class FloatBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):652 def setUp(self):653 super(FloatBinaryOpsTest, self).setUp()654 self.ops = [655 ('igamma', None, math_ops.igamma, core.igamma),656 ('igammac', None, math_ops.igammac, core.igammac),657 ('zeta', None, math_ops.zeta, core.zeta),658 ('polygamma', None, math_ops.polygamma, core.polygamma),659 ('maximum', None, math_ops.maximum, core.maximum),660 ('minimum', None, math_ops.minimum, core.minimum),661 ('squared_difference', None, math_ops.squared_difference,662 core.squared_difference),663 ]664 total_size = np.prod([v.size for v in self.original_lt.axes.values()])665 test_lt = core.LabeledTensor(666 math_ops.cast(self.original_lt, dtypes.float32) / total_size,667 self.original_lt.axes)668 self.test_lt_1 = test_lt669 self.test_lt_2 = 1.0 - test_lt670 self.test_lt_1_broadcast = self.test_lt_1.tensor671 self.test_lt_2_broadcast = self.test_lt_2.tensor672 self.broadcast_axes = self.test_lt_1.axes673if __name__ == '__main__':...
setup.py
Source:setup.py
1# coding: utf-82"""3 Telstra Messaging API4 # Introduction Send and receive SMS and MMS messages globally using Telstra's enterprise-grade Messaging API. It also allows your application to track the delivery status of both sent and received messages. Get your dedicated Australian number, and start sending and receiving messages today. # Features <p>The Telstra Messaging API provides the features below. <table> <thead> <tr> <th>Feature</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>Dedicated Number</code></td> <td>Provision a mobile number for your account to be used as the from address in the API</td> </tr> <tr> <td><code>Send Messages</code></td> <td>Sending SMS or MMS messages</td> </tr> <tr> <td><code>Receive Messages</code></td> <td>Telstra will deliver messages sent to a dedicated number or to the <code>notifyURL</code> defined by you</td> </tr> <tr> <td><code>Broadcast Messages</code></td> <td>Invoke a single API to send a message to a list of numbers provided in <code>to</code></td> </tr> <tr> <td><code>Delivery Status</code></td> <td>Query the delivery status of your messages</td> </tr> <tr> <td><code>Callbacks</code></td> <td>Provide a notification URL and Telstra will notify your app when a message's status changes</td> </tr> <tr> <td><code>Alphanumeric Identifier</code></td> <td>Differentiate yourself by providing an alphanumeric string in <code>from</code>. This feature is only available on paid plans</td> </tr> <tr> <td><code>Concatenation</code></td> <td>Send messages up to 1900 characters long and Telstra will automatically segment and reassemble them</td> </tr> <tr> <td><code>Reply Request</code></td> <td>Create a chat session by associating <code>messageId</code> and <code>to</code> number to track responses received from a mobile number.
We will store this association for 8 days</td> </tr> <tr> <td><code>Character set</code></td> <td>Accepts all Unicode characters as part of UTF-8</td> </tr> <tr> <td><code>Bounce-back response</code></td> <td>See if your SMS hits an unreachable or unallocated number (Australia Only)</td> </tr> <tr> <td><code>Queuing</code></td> <td>Messaging API will automatically queue and deliver each message at a compliant rate.</td> </tr> <tr> <td><code>Emoji Encoding</code></td> <td>The API supports the encoding of the full range of emojis. Emojis in the reply messages will be in their UTF-8 format.</td> </tr> </tbody> </table> # Getting Access to the API <ol> <li>Register at <a href=\"https://dev.telstra.com\">https://dev.telstra.com</a>. <li>After registration, login to <a href=\"https://dev.telstra.com\">https://dev.telstra.com</a> and navigate to the "My apps" page. <li>Create your application by clicking the "Add new app" button <li>Select "API Free Trial" Product when configuring your application. This Product includes the Telstra Messaging API as well as other APIs. Your application will be approved automatically. <li>There is a maximum of 1000 free messages per developer. Additional messages and features can be purchased from <a href=\"https://dev.telstra.com\">https://dev.telstra.com</a>. <li>Note your <code>Client key</code> and <code>Client secret</code> as these will be needed to provision a number for your application and for authentication. </ol> <p>Now head over to <b>Getting Started</b> where you can find a postman collection as well as some links to sample apps and SDKs to get you started. <p>Happy Messaging! 
# Getting Started <p>Below are the steps to get started with the Telstra Messaging API.</p> <ol> <li>Generate OAuth2 Token using your <code>Client key</code> and <code>Client secret</code>.</li> <li>Create Subscription in order to receive a provisioned number.</li> <li>Send Message to a specific mobile number.</li> </ol> <h2>Run in Postman</h2> <p><a href=\"https://app.getpostman.com/run-collection/ded00578f69a9deba256#?env%5BMessaging%20API%20Environments%5D=W3siZW5hYmxlZCI6dHJ1ZSwia2V5IjoiY2xpZW50X2lkIiwidmFsdWUiOiIiLCJ0eXBlIjoidGV4dCJ9LHsiZW5hYmxlZCI6dHJ1ZSwia2V5IjoiY2xpZW50X3NlY3JldCIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifSx7ImVuYWJsZWQiOnRydWUsImtleSI6ImFjY2Vzc190b2tlbiIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifSx7ImVuYWJsZWQiOnRydWUsImtleSI6Imhvc3QiLCJ2YWx1ZSI6InRhcGkudGVsc3RyYS5jb20iLCJ0eXBlIjoidGV4dCJ9LHsiZW5hYmxlZCI6dHJ1ZSwia2V5IjoiQXV0aG9yaXphdGlvbiIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifSx7ImVuYWJsZWQiOnRydWUsImtleSI6Im9hdXRoX2hvc3QiLCJ2YWx1ZSI6InNhcGkudGVsc3RyYS5jb20iLCJ0eXBlIjoidGV4dCJ9LHsiZW5hYmxlZCI6dHJ1ZSwia2V5IjoibWVzc2FnZV9pZCIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifV0=\"> <img src=\"https://run.pstmn.io/button.svg\" alt=\"Run in Postman\" /></a></p> <h2>Sample Apps</h2> - <a href=\"https://github.com/telstra/MessagingAPI-perl-sample-app\">Perl Sample App</a> - <a href=\"https://github.com/telstra/messaging-sample-code-happy-chat\">Happy Chat App</a> - <a href=\"https://github.com/developersteve/telstra-messaging-php\">PHP Sample App</a> <h2>SDK repos</h2> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-php\">Messaging API - PHP SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-python\">Messaging API - Python SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-ruby\">Messaging API - Ruby SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-node\">Messaging API - NodeJS SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-dotnet\">Messaging API - .Net2 SDK</a> - <a 
href=\"https://github.com/telstra/MessagingAPI-SDK-Java\">Messaging API - Java SDK</a> # Delivery Notification The API provides several methods for notifying when a message has been delivered to the destination. <ol> <li>When you provision a number there is an opportunity to specify a <code>notifyURL</code>; when the message has been delivered, the API will make a call to this URL to advise of the message status. If this is not provided, then you can make use of the Get Replies API to poll for messages.</li> <li>If you do not specify a URL, you can always call the <code>GET /sms</code> API to get the latest replies to the message.</li> </ol> <I>Please note that the notification URLs and the polling call are exclusive. If a notification URL has been set then the polling call will not provide any useful information.</I> <h2>Notification URL Format</h2> When a message has reached its final state, the API will send a POST to the URL that has been previously specified. <h3>Notification URL Format for SMS</h3> <pre><code class=\"language-sh\">{ to: '+61418123456' sentTimestamp: '2017-03-17T10:05:22+10:00' receivedTimestamp: '2017-03-17T10:05:23+10:00' messageId: /cccb284200035236000000000ee9d074019e0301/1261418123456 deliveryStatus: DELIVRD } </code></pre> \\ The fields are: <table> <thead> <tr> <th>Field</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>to</code></td> <td>The number the message was sent to.</td> </tr> <tr> <td><code>receivedTimestamp</code></td> <td>Time the message was sent to the API.</td> </tr> <tr> <td><code>sentTimestamp</code></td> <td>Time handling of the message ended.</td> </tr> <tr> <td><code>deliveryStatus</code></td> <td>The final state of the message.</td> </tr> <tr> <td><code>messageId</code></td> <td>The same reference that was returned when the original message was sent.</td> </tr> </tbody> </table> Upon receiving this call it is 
expected that your servers will give a 204 (No Content) response. Anything else will cause the API to reattempt the call 5 minutes later. <h3>Notification URL Format for SMS Replies</h3> <pre><code class=\"language-sh\">{ \"status\": \"RECEIVED\", \"destinationAddress\": \"+61418123456\", \"senderAddress\": \"+61421987654\", \"message\": \"Foo\", \"sentTimestamp\": \"2018-03-23T12:10:06+10:00\" } </code></pre> \\ The fields are: <table> <thead> <tr> <th>Field</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>status</code></td> <td>The final state of the message.</td> </tr> <tr> <td><code>destinationAddress</code></td> <td>The number the message was sent to.</td> </tr> <tr> <td><code>senderAddress</code></td> <td>The number the message was sent from.</td> </tr> <tr> <td><code>message</code></td> <td>The content of the SMS reply.</td> </tr> <tr> <td><code>sentTimestamp</code></td> <td>Time handling of the message ended.</td> </tr> </tbody> </table> <h3>Notification URL Format for MMS Replies</h3> <pre><code class=\"language-sh\">{ \"status\": \"RECEIVED\", \"destinationAddress\": \"+61418123456\", \"senderAddress\": \"+61421987654\", \"subject\": \"Foo\", \"sentTimestamp\": \"2018-03-23T12:15:45+10:00\", \"envelope\": \"string\", \"MMSContent\": [ { \"type\": \"application/smil\", \"filename\": \"smil.xml\", \"payload\": \"string\" }, { \"type\": \"image/jpeg\", \"filename\": \"sample.jpeg\", \"payload\": \"string\" } ] } </code></pre> \\ The fields are: <table> <thead> <tr> <th>Field</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>status</code></td> <td>The final state of the message.</td> </tr> <tr> <td><code>destinationAddress</code></td> <td>The number the message was sent to.</td> </tr> <tr> <td><code>senderAddress</code></td> <td>The number the message was sent from.</td> </tr> <tr> <td><code>subject</code></td> <td>The subject assigned to the message.</td> </tr> <tr> <td><code>sentTimestamp</code></td> <td>Time handling of the message 
ended.</td> </tr> <tr> <td><code>envelope</code></td> <td>Information about about terminal type and originating operator.</td> </tr> <tr> <td><code>MMSContent</code></td> <td>An array of the actual content of the reply message.</td> </tr> <tr> <td><code>type</code></td> <td>The content type of the message.</td> </tr> <tr> <td><code>filename</code></td> <td>The filename for the message content.</td> </tr> <tr> <td><code>payload</code></td> <td>The content of the message.</td> </tr> </tbody> </table> # Frequently Asked Questions **Q: Can I send a broadcast message using the Telstra Messging API?** A. Yes. Recipient numbers can be in teh form of an array of strings if a broadcast message needs to be sent. <h2>Notes</h2> <a href=\"http://petstore.swagger.io/?url=https://raw.githubusercontent.com/telstra/MessagingAPI-v2/master/docs/swagger/messaging-api-swagger.yaml\" target=\"_blank\">View messaging in Swagger UI</a> # noqa: E5015 OpenAPI spec version: 2.2.66 7 Generated by: https://github.com/swagger-api/swagger-codegen.git8"""9from setuptools import setup, find_packages # noqa: H30110NAME = "swagger-client"11VERSION = "1.0.0"12# To install the library, run the following13#14# python setup.py install15#16# prerequisite: setuptools17# http://pypi.python.org/pypi/setuptools18REQUIRES = ["urllib3 >= 1.15", "six >= 1.10", "certifi", "python-dateutil"]19setup(20 name=NAME,21 version=VERSION,22 description="Telstra Messaging API",23 author_email="",24 url="",25 keywords=["Swagger", "Telstra Messaging API"],26 install_requires=REQUIRES,27 packages=find_packages(),28 include_package_data=True,29 long_description="""\30 # Introduction Send and receive SMS and MMS messages globally using Telstraââ¬â¢s enterprise grade Messaging API. It also allows your application to track the delivery status of both sent and received messages. Get your dedicated Australian number, and start sending and receiving messages today. 
# Features <p>The Telstra Messaging API provides the features below. <table> <thead> <tr> <th>Feature</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>Dedicated Number</code></td> <td>Provision a mobile number for your account to be used as from address in the API</td> </tr> <tr> <td><code>Send Messages</code></td> <td>Sending SMS or MMS messages</td> </tr> <tr> <td><code>Receive Messages</code></td> <td>Telstra will deliver messages sent to a dedicated number or to the <code>notifyURL</code> defined by you</td> </tr> <tr> <td><code>Broadcast Messages</code></td> <td>Invoke a single API to send a message to a list of number provided in <code>to</code></td> </tr> <tr> <td><code>Delivery Status</code></td> <td>Query the delivery status of your messages</td> </tr> <tr> <td><code>Callbacks</code></td> <td>Provide a notification URL and Telstra will notify your app when messages status changes</td> </tr> <tr> <td><code>Alphanumeric Identifier</code></td> <td>Differentiate yourself by providing an alphanumeric string in <code>from</code>. This feature is only available on paid plans</td> </tr> <tr> <td><code>Concatenation</code></td> <td>Send messages up to 1900 characters long and Telstra will automaticaly segment and reassemble them</td> </tr> <tr> <td><code>Reply Request</code></td> <td>Create a chat session by associating <code>messageId</code> and <code>to</code> number to track responses received from a mobile number. We will store this association for 8 days</td> </tr> <tr> <td><code>Character set</code></td> <td>Accepts all Unicode characters as part of UTF-8</td> </tr> <tr> <td><code>Bounce-back response</code></td> <td>See if your SMS hits an unreachable or unallocated number (Australia Only)</td> </tr> <tr> <td><code>Queuing</code></td> <td>Messaging API will automatically queue and deliver each message at a compliant rate.</td> </tr> <tr> <td><code>Emoji Encoding</code></td> <td>The API supports the encoding of the full range of emojis. 
Emojis in the reply messages will be in their UTF-8 format.</td> </tr> </tbody> </table> # Getting Access to the API <ol> <li>Register at <a href=\"https://dev.telstra.com\">https://dev.telstra.com</a>. <li>After registration, login to <a href=\"https://dev.telstra.com\">https://dev.telstra.com</a> and navigate to the &quot;My apps&quot; page. <li>Create your application by clicking the &quot;Add new app&quot; button <li>Select &quot;API Free Trial&quot; Product when configuring your application. This Product includes the Telstra Messaging API as well as other APIs. Your application will be approved automatically. <li>There is a maximum of 1000 free messages per developer. Additional messages and features can be purchased from <a href=\"https://dev.telstra.com\">https://dev.telstra.com</a>. <li>Note your <code>Client key</code> and <code>Client secret</code> as these will be needed to provision a number for your application and for authentication. </ol> <p>Now head over to <b>Getting Started</b> where you can find a postman collection as well as some links to sample apps and SDKs to get you started. <p>Happy Messaging! 
# Getting Started <p>Below are the steps to get started with the Telstra Messaging API.</p> <ol> <li>Generate OAuth2 Token using your <code>Client key</code> and <code>Client secret</code>.</li> <li>Create Subscription in order to receive a provisioned number.</li> <li>Send Message to a specific mobile number.</li> </ol> <h2>Run in Postman</h2> <p><a href=\"https://app.getpostman.com/run-collection/ded00578f69a9deba256#?env%5BMessaging%20API%20Environments%5D=W3siZW5hYmxlZCI6dHJ1ZSwia2V5IjoiY2xpZW50X2lkIiwidmFsdWUiOiIiLCJ0eXBlIjoidGV4dCJ9LHsiZW5hYmxlZCI6dHJ1ZSwia2V5IjoiY2xpZW50X3NlY3JldCIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifSx7ImVuYWJsZWQiOnRydWUsImtleSI6ImFjY2Vzc190b2tlbiIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifSx7ImVuYWJsZWQiOnRydWUsImtleSI6Imhvc3QiLCJ2YWx1ZSI6InRhcGkudGVsc3RyYS5jb20iLCJ0eXBlIjoidGV4dCJ9LHsiZW5hYmxlZCI6dHJ1ZSwia2V5IjoiQXV0aG9yaXphdGlvbiIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifSx7ImVuYWJsZWQiOnRydWUsImtleSI6Im9hdXRoX2hvc3QiLCJ2YWx1ZSI6InNhcGkudGVsc3RyYS5jb20iLCJ0eXBlIjoidGV4dCJ9LHsiZW5hYmxlZCI6dHJ1ZSwia2V5IjoibWVzc2FnZV9pZCIsInZhbHVlIjoiIiwidHlwZSI6InRleHQifV0=\"> <img src=\"https://run.pstmn.io/button.svg\" alt=\"Run in Postman\" /></a></p> <h2>Sample Apps</h2> - <a href=\"https://github.com/telstra/MessagingAPI-perl-sample-app\">Perl Sample App</a> - <a href=\"https://github.com/telstra/messaging-sample-code-happy-chat\">Happy Chat App</a> - <a href=\"https://github.com/developersteve/telstra-messaging-php\">PHP Sample App</a> <h2>SDK repos</h2> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-php\">Messaging API - PHP SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-python\">Messaging API - Python SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-ruby\">Messaging API - Ruby SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-node\">Messaging API - NodeJS SDK</a> - <a href=\"https://github.com/telstra/MessagingAPI-SDK-dotnet\">Messaging API - .Net2 SDK</a> - <a 
href=\"https://github.com/telstra/MessagingAPI-SDK-Java\">Messaging API - Java SDK</a> # Delivery Notification The API provides several methods for notifying when a message has been delivered to the destination. <ol> <li>When you provision a number there is an opportunity to specify a <code>notifyURL</code>, when the message has been delivered the API will make a call to this URL to advise of the message status. If this is not provided then you can make use the Get Replies API to poll for messages.</li> <li>If you do not specify a URL you can always call the <code>GET /sms</code> API get the latest replies to the message.</li> </ol> <I>Please note that the notification URLs and the polling call are exclusive. If a notification URL has been set then the polling call will not provide any useful information.</I> <h2>Notification URL Format</h2> When a message has reached its final state, the API will send a POST to the URL that has been previously specified. <h3>Notification URL Format for SMS</h3> <pre><code class=\"language-sh\">{ to: '+61418123456' sentTimestamp: '2017-03-17T10:05:22+10:00' receivedTimestamp: '2017-03-17T10:05:23+10:00' messageId: /cccb284200035236000000000ee9d074019e0301/1261418123456 deliveryStatus: DELIVRD } </code></pre> \\ The fields are: <table> <thead> <tr> <th>Field</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>to</code></td> <td>The number the message was sent to.</td> </tr> <tr> <td><code>receivedTimestamp</code></td> <td>Time the message was sent to the API.</td> </tr> <tr> <td><code>sentTimestamp</code></td> <td>Time handling of the message ended.</td> </tr> <tr> <td><code>deliveryStatus</code></td> <td>The final state of the message.</td> </tr> <tr> <td><code>messageId</code></td> <td>The same reference that was returned when the original message was sent.</td> </tr> <tr> <td><code>receivedTimestamp</code></td> <td>Time the message was sent to the API.</td> </tr> </tbody> </table> Upon receiving this call it is 
expected that your servers will give a 204 (No Content) response. Anything else will cause the API to reattempt the call 5 minutes later. <h3>Notification URL Format for SMS Replies</h3> <pre><code class=\"language-sh\">{ \"status\": \"RECEIVED\" \"destinationAddress\": \"+61418123456\" \"senderAddress\": \"+61421987654\" \"message\": \"Foo\" \"sentTimestamp\": \"2018-03-23T12:10:06+10:00\" } </code></pre> \\ The fields are: <table> <thead> <tr> <th>Field</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>status</code></td> <td>The final state of the message.</td> </tr> <tr> <td><code>destinationAddress</code></td> <td>The number the message was sent to.</td> </tr> <tr> <td><code>senderAddress</code></td> <td>The number the message was sent from.</td> </tr> <tr> <td><code>message</code></td> <td>The sontent of the SMS reply.</td> </tr> <tr> <td><code>sentTimestamp</code></td> <td>Time handling of the message ended.</td> </tr> </tbody> </table> <h3>Notification URL Format for MMS Replies</h3> <pre><code class=\"language-sh\">{ \"status\": \"RECEIVED\", \"destinationAddress\": \"+61418123456\", \"senderAddress\": \"+61421987654\", \"subject\": \"Foo\", \"sentTimestamp\": \"2018-03-23T12:15:45+10:00\", \"envelope\": \"string\", \"MMSContent\": [ { \"type\": \"application/smil\", \"filename\": \"smil.xml\", \"payload\": \"string\" }, { \"type\": \"image/jpeg\", \"filename\": \"sample.jpeg\", \"payload\": \"string\" } ] } </code></pre> \\ The fields are: <table> <thead> <tr> <th>Field</th> <th>Description</th> </tr> </thead> <tbody> <tr> <td><code>status</code></td> <td>The final state of the message.</td> </tr> <tr> <td><code>destinationAddress</code></td> <td>The number the message was sent to.</td> </tr> <tr> <td><code>senderAddress</code></td> <td>The number the message was sent from.</td> </tr> <tr> <td><code>subject</code></td> <td>The subject assigned to the message.</td> </tr> <tr> <td><code>sentTimestamp</code></td> <td>Time handling of the message 
ended.</td> </tr> <tr> <td><code>envelope</code></td> <td>Information about about terminal type and originating operator.</td> </tr> <tr> <td><code>MMSContent</code></td> <td>An array of the actual content of the reply message.</td> </tr> <tr> <td><code>type</code></td> <td>The content type of the message.</td> </tr> <tr> <td><code>filename</code></td> <td>The filename for the message content.</td> </tr> <tr> <td><code>payload</code></td> <td>The content of the message.</td> </tr> </tbody> </table> # Frequently Asked Questions **Q: Can I send a broadcast message using the Telstra Messging API?** A. Yes. Recipient numbers can be in teh form of an array of strings if a broadcast message needs to be sent. <h2>Notes</h2> <a href=\"http://petstore.swagger.io/?url=https://raw.githubusercontent.com/telstra/MessagingAPI-v2/master/docs/swagger/messaging-api-swagger.yaml\" target=\"_blank\">View messaging in Swagger UI</a> # noqa: E50131 """...
cpso.py
Source:cpso.py
1import requests2from bs4 import BeautifulSoup3import pandas as pd4from string import ascii_lowercase5import re6import time7from random import randint8import csv9import sys10import os11url_search = "https://www.cpso.on.ca/Public-Information-Services/Find-a-Doctor?search=general"12url_paging = "https://www.cpso.on.ca/Public-Register-Info-(1)/Doctor-Search-Results"13abs_file_path = os.path.abspath(__file__)14file_dir = os.path.dirname(os.path.abspath(__file__))15project_dir = os.path.dirname(file_dir)16# progess bar17def progress(count, total, status=''):18 bar_len = 6019 filled_len = int(round(bar_len * count / float(total)))20 percents = round(100.0 * count / float(total), 1)21 bar = '=' * filled_len + '-' * (bar_len - filled_len)22 sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', status[0:40].ljust(40)))23 sys.stdout.flush()24def scrape_doctors(soup):25 new_docs = {}26 for doctor in soup.find_all(class_ = 'doctor-search-results--result'):27 doc_str = doctor.find('h3').text28 new_docs[ (re.search("\d+", doc_str).group()) ] = doctor29 return new_docs30headers_search = {31 "Host": "www.cpso.on.ca",32 "User-Agent": "Adam Kowalczewski, Service Canada, adam.kowalczewski@servicecanada.gc.ca",33 "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",34 "Accept-Language": "en-US,en;q=0.5",35 "Accept-Encoding": "gzip, deflate, br",36 "Referer": "https://www.cpso.on.ca/Public-Information-Services/Find-a-Doctor?search=general",37 "Content-Type": "application/x-www-form-urlencoded",38 "Connection": "keep-alive",39 "Cookie": "CMSPreferredCulture=en-CA; _ga=GA1.3.788028045.1540596057; _gid=GA1.3.1509832542.1540596057; CMSCsrfCookie=PB+DjsX3FE9SCgAk3jGZzc65Ld6NEaE755HpaDB9; ASP.NET_SessionId=pek14izhwnw4u0curdgkxygh; _gat_UA-36725164-1=1",40 "Upgrade-Insecure-Requests": "1"41}42headers_paging = {43 "Referer": "https://www.cpso.on.ca/Public-Register-Info-(1)/Doctor-Search-Results"44}45manScript = 
"%3b%3bAjaxControlToolkit%2c+Version%3d4.1.60919.0%2c+Culture%3dneutral%2c+PublicKeyToken%3d28f01b0e84b6d53e%3aen-CA%3aee051b62-9cd6-49a5-87bb-93c07bc43d63%3a475a4ef5%3aeffe2a26%3a7e63a579"46manScript = manScript.replace( "%3b", ";" )47manScript = manScript.replace( "%2c", "," )48manScript = manScript.replace( "%3d", "=" )49manScript = manScript.replace( "%3a", ":" )50manScript = manScript.replace( "en-CA", "en-US")51def crawl_cpso( city_code = '', fsa = '', char = '' ):52 doctors = {}53 # get HTML to generate POST data54 with requests.Session() as s:55 r = s.get(url_search)56 soup = BeautifulSoup(r.content, 'html.parser')57 payload = {58 "__CMSCsrfToken": soup.find("input", id="__CMSCsrfToken")['value'],59 "__VIEWSTATE": soup.find("input", id="__VIEWSTATE")['value'],60 "__VIEWSTATEGENERATOR": soup.find("input", id="__VIEWSTATEGENERATOR")['value'],61 "lng": 'en-CA',62 "manScript_HiddenField": manScript,63 "searchType":"general",64 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$advancedState":"closed",65 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ConcernsState":"closed",66 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ddCity": city_code,67 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$txtPostalCode": fsa,68 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$txtLastName": char,69 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$grpGender":"+",70 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$grpDocType":"rdoDocTypeAll",71 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ddHospitalName":"-1",72 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ddLanguage":"08",73 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkActiveDoctors":"on",74 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkPracticeRestrictions":"on",75 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkPendingHearings":"on",76 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkPastHearings":"on",77 
"p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkHospitalNotices":"on",78 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkConcerns":"on",79 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkNotices":"on",80 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$btnSubmit1":"Submit"81 }82 # send POST request83 time.sleep(randint(1,3))84 r = s.post( url_search, data = payload )85 soup = BeautifulSoup(r.content, 'html.parser')86 doctors.update( scrape_doctors(soup) )87 while True:88 # stop if there's no pages89 try:90 n_in_group = len(soup.find('div', class_ = "doctor-search-paging").find_all('a', id = re.compile("rptPages")))91 except:92 break93 # page through the rest of the group94 for i in range(1, n_in_group):95 payload_paging = {96 "__CMSCsrfToken": soup.find("input", id="__CMSCsrfToken")['value'],97 "__VIEWSTATE": soup.find("input", id="__VIEWSTATE")['value'],98 "__VIEWSTATEGENERATOR": soup.find("input", id="__VIEWSTATEGENERATOR")['value'],99 "__EVENTTARGET": "p$lt$ctl04$pageplaceholder$p$lt$ctl03$CPSO_DoctorSearchResults$rptPages$ctl0{:1}$lnbPage",100 "lng": 'en-CA',101 "manScript_HiddenField": manScript,102 "p$lt$ctl04$pageplaceholder$p$lt$ctl03$CPSO_DoctorSearchResults$hdnCurrentPage": 1103 }104 payload_paging['__EVENTTARGET'] = payload_paging['__EVENTTARGET'].format(i)105 time.sleep(randint(1,3))106 r = s.post( url_paging, headers = headers_search.update(headers_paging), data = payload_paging )107 soup = BeautifulSoup(r.content, 'html.parser')108 #page_num = soup.find('div', class_ = 'doctor-search-count').find('div', class_ = 'text-align--right').text.strip()109 doctors.update( scrape_doctors(soup) )110 payload_paging['p$lt$ctl04$pageplaceholder$p$lt$ctl03$CPSO_DoctorSearchResults$hdnCurrentPage'] += 1111 # stop if there's no more groups112 if soup.find(class_ = "aspNetDisabled next") != None:113 break114 # switch to the next group115 payload_paging['__EVENTTARGET'] = 
'p$lt$ctl04$pageplaceholder$p$lt$ctl03$CPSO_DoctorSearchResults$lnbNextGroup'116 time.sleep(randint(1,3))117 r = s.post( url_paging, data = payload_paging )118 soup = BeautifulSoup(r.content, 'html.parser')119 return doctors120def count_doctors( city_code = '', fsa = '', char = '' ):121 with requests.Session() as s:122 r = s.get(url_search)123 soup = BeautifulSoup(r.content, 'html.parser')124 payload = {125 "__CMSCsrfToken": soup.find("input", id="__CMSCsrfToken")['value'],126 "__VIEWSTATE": soup.find("input", id="__VIEWSTATE")['value'],127 "__VIEWSTATEGENERATOR": soup.find("input", id="__VIEWSTATEGENERATOR")['value'],128 "lng": 'en-CA',129 "manScript_HiddenField": manScript,130 "searchType":"general",131 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$advancedState":"closed",132 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ConcernsState":"closed",133 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ddCity": city_code,134 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$txtPostalCode": fsa,135 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$txtLastName": char,136 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$grpGender":"+",137 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$grpDocType":"rdoDocTypeAll",138 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ddHospitalName":"-1",139 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$ddLanguage":"08",140 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkActiveDoctors":"on",141 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkPracticeRestrictions":"on",142 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkPendingHearings":"on",143 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkPastHearings":"on",144 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkHospitalNotices":"on",145 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkConcerns":"on",146 
"p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$chkNotices":"on",147 "p$lt$ctl04$pageplaceholder$p$lt$ctl02$AllDoctorsSearch$btnSubmit1":"Submit"148 }149 # get doctor search result count150 time.sleep(randint(1,3))151 r = s.post(url_search, data = payload)152 soup = BeautifulSoup(r.content, 'html.parser')153 target = soup.find('div', class_ = "doctor-search-count").strong.text...
Learn to execute automation testing from scratch with the LambdaTest Learning Hub: from setting up the prerequisites and running your first automation test, to following best practices and diving deeper into advanced test scenarios. LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!