Best Python code snippet using pytest
basic_session_run_hooks_test.py
Source:basic_session_run_hooks_test.py
1# pylint: disable=g-bad-file-header2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.3#4# Licensed under the Apache License, Version 2.0 (the "License");5# you may not use this file except in compliance with the License.6# You may obtain a copy of the License at7#8# http://www.apache.org/licenses/LICENSE-2.09#10# Unless required by applicable law or agreed to in writing, software11# distributed under the License is distributed on an "AS IS" BASIS,12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.13# See the License for the specific language governing permissions and14# limitations under the License.15# ==============================================================================16"""Tests for basic_session_run_hooks."""17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20import shutil21import tempfile22import threading23import time24from tensorflow.contrib.framework.python.framework import checkpoint_utils25from tensorflow.contrib.framework.python.ops import variables26from tensorflow.contrib.testing.python.framework import fake_summary_writer27from tensorflow.python.client import session as session_lib28from tensorflow.python.framework import constant_op29from tensorflow.python.framework import dtypes30from tensorflow.python.framework import meta_graph31from tensorflow.python.framework import ops32from tensorflow.python.ops import array_ops33from tensorflow.python.ops import control_flow_ops34from tensorflow.python.ops import state_ops35from tensorflow.python.ops import variable_scope36from tensorflow.python.ops import variables as variables_lib37import tensorflow.python.ops.nn_grad # pylint: disable=unused-import38from tensorflow.python.platform import test39from tensorflow.python.platform import tf_logging40from tensorflow.python.summary import summary as summary_lib41from tensorflow.python.summary.writer import writer_cache42from tensorflow.python.training 
import basic_session_run_hooks43from tensorflow.python.training import monitored_session44from tensorflow.python.training import session_run_hook45class MockCheckpointSaverListener(46 basic_session_run_hooks.CheckpointSaverListener):47 def __init__(self):48 self.begin_count = 049 self.before_save_count = 050 self.after_save_count = 051 self.end_count = 052 def begin(self):53 self.begin_count += 154 def before_save(self, session, global_step):55 self.before_save_count += 156 def after_save(self, session, global_step):57 self.after_save_count += 158 def end(self, session, global_step):59 self.end_count += 160 def get_counts(self):61 return {62 'begin': self.begin_count,63 'before_save': self.before_save_count,64 'after_save': self.after_save_count,65 'end': self.end_count66 }67class SecondOrStepTimerTest(test.TestCase):68 def test_raise_in_both_secs_and_steps(self):69 with self.assertRaises(ValueError):70 basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10)71 def test_raise_in_none_secs_and_steps(self):72 with self.assertRaises(ValueError):73 basic_session_run_hooks.SecondOrStepTimer()74 def test_every_secs(self):75 timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)76 self.assertTrue(timer.should_trigger_for_step(1))77 timer.update_last_triggered_step(1)78 self.assertFalse(timer.should_trigger_for_step(1))79 self.assertFalse(timer.should_trigger_for_step(2))80 time.sleep(1.0)81 self.assertFalse(timer.should_trigger_for_step(1))82 self.assertTrue(timer.should_trigger_for_step(2))83 def test_every_steps(self):84 timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3)85 self.assertTrue(timer.should_trigger_for_step(1))86 timer.update_last_triggered_step(1)87 self.assertFalse(timer.should_trigger_for_step(1))88 self.assertFalse(timer.should_trigger_for_step(2))89 self.assertFalse(timer.should_trigger_for_step(3))90 self.assertTrue(timer.should_trigger_for_step(4))91 def test_update_last_triggered_step(self):92 timer = 
basic_session_run_hooks.SecondOrStepTimer(every_steps=1)93 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1)94 self.assertEqual(None, elapsed_secs)95 self.assertEqual(None, elapsed_steps)96 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5)97 self.assertLess(0, elapsed_secs)98 self.assertEqual(4, elapsed_steps)99 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7)100 self.assertLess(0, elapsed_secs)101 self.assertEqual(2, elapsed_steps)102class StopAtStepTest(test.TestCase):103 def test_raise_in_both_last_step_and_num_steps(self):104 with self.assertRaises(ValueError):105 basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20)106 def test_stop_based_on_last_step(self):107 h = basic_session_run_hooks.StopAtStepHook(last_step=10)108 with ops.Graph().as_default():109 global_step = variables.get_or_create_global_step()110 no_op = control_flow_ops.no_op()111 h.begin()112 with session_lib.Session() as sess:113 mon_sess = monitored_session._HookedSession(sess, [h])114 sess.run(state_ops.assign(global_step, 5))115 h.after_create_session(sess, None)116 mon_sess.run(no_op)117 self.assertFalse(mon_sess.should_stop())118 sess.run(state_ops.assign(global_step, 9))119 mon_sess.run(no_op)120 self.assertFalse(mon_sess.should_stop())121 sess.run(state_ops.assign(global_step, 10))122 mon_sess.run(no_op)123 self.assertTrue(mon_sess.should_stop())124 sess.run(state_ops.assign(global_step, 11))125 mon_sess._should_stop = False126 mon_sess.run(no_op)127 self.assertTrue(mon_sess.should_stop())128 def test_stop_based_on_num_step(self):129 h = basic_session_run_hooks.StopAtStepHook(num_steps=10)130 with ops.Graph().as_default():131 global_step = variables.get_or_create_global_step()132 no_op = control_flow_ops.no_op()133 h.begin()134 with session_lib.Session() as sess:135 mon_sess = monitored_session._HookedSession(sess, [h])136 sess.run(state_ops.assign(global_step, 5))137 h.after_create_session(sess, None)138 mon_sess.run(no_op)139 
self.assertFalse(mon_sess.should_stop())140 sess.run(state_ops.assign(global_step, 13))141 mon_sess.run(no_op)142 self.assertFalse(mon_sess.should_stop())143 sess.run(state_ops.assign(global_step, 14))144 mon_sess.run(no_op)145 self.assertFalse(mon_sess.should_stop())146 sess.run(state_ops.assign(global_step, 15))147 mon_sess.run(no_op)148 self.assertTrue(mon_sess.should_stop())149 sess.run(state_ops.assign(global_step, 16))150 mon_sess._should_stop = False151 mon_sess.run(no_op)152 self.assertTrue(mon_sess.should_stop())153 def test_stop_based_with_multiple_steps(self):154 h = basic_session_run_hooks.StopAtStepHook(num_steps=10)155 with ops.Graph().as_default():156 global_step = variables.get_or_create_global_step()157 no_op = control_flow_ops.no_op()158 h.begin()159 with session_lib.Session() as sess:160 mon_sess = monitored_session._HookedSession(sess, [h])161 sess.run(state_ops.assign(global_step, 5))162 h.after_create_session(sess, None)163 mon_sess.run(no_op)164 self.assertFalse(mon_sess.should_stop())165 sess.run(state_ops.assign(global_step, 15))166 mon_sess.run(no_op)167 self.assertTrue(mon_sess.should_stop())168class LoggingTensorHookTest(test.TestCase):169 def setUp(self):170 # Mock out logging calls so we can verify whether correct tensors are being171 # monitored.172 self._actual_log = tf_logging.info173 self.logged_message = None174 def mock_log(*args, **kwargs):175 self.logged_message = args176 self._actual_log(*args, **kwargs)177 tf_logging.info = mock_log178 def tearDown(self):179 tf_logging.info = self._actual_log180 def test_illegal_args(self):181 with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):182 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0)183 with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):184 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10)185 with self.assertRaisesRegexp(ValueError, 'xactly one of'):186 basic_session_run_hooks.LoggingTensorHook(187 
tensors=['t'], every_n_iter=5, every_n_secs=5)188 with self.assertRaisesRegexp(ValueError, 'xactly one of'):189 basic_session_run_hooks.LoggingTensorHook(tensors=['t'])190 def test_print_at_end_only(self):191 with ops.Graph().as_default(), session_lib.Session() as sess:192 t = constant_op.constant(42.0, name='foo')193 train_op = constant_op.constant(3)194 hook = basic_session_run_hooks.LoggingTensorHook(195 tensors=[t.name], at_end=True)196 hook.begin()197 mon_sess = monitored_session._HookedSession(sess, [hook])198 sess.run(variables_lib.global_variables_initializer())199 self.logged_message = ''200 for _ in range(3):201 mon_sess.run(train_op)202 # assertNotRegexpMatches is not supported by python 3.1 and later203 self.assertEqual(str(self.logged_message).find(t.name), -1)204 hook.end(sess)205 self.assertRegexpMatches(str(self.logged_message), t.name)206 def _validate_print_every_n_steps(self, sess, at_end):207 t = constant_op.constant(42.0, name='foo')208 train_op = constant_op.constant(3)209 hook = basic_session_run_hooks.LoggingTensorHook(210 tensors=[t.name], every_n_iter=10, at_end=at_end)211 hook.begin()212 mon_sess = monitored_session._HookedSession(sess, [hook])213 sess.run(variables_lib.global_variables_initializer())214 mon_sess.run(train_op)215 self.assertRegexpMatches(str(self.logged_message), t.name)216 for _ in range(3):217 self.logged_message = ''218 for _ in range(9):219 mon_sess.run(train_op)220 # assertNotRegexpMatches is not supported by python 3.1 and later221 self.assertEqual(str(self.logged_message).find(t.name), -1)222 mon_sess.run(train_op)223 self.assertRegexpMatches(str(self.logged_message), t.name)224 # Add additional run to verify proper reset when called multiple times.225 self.logged_message = ''226 mon_sess.run(train_op)227 # assertNotRegexpMatches is not supported by python 3.1 and later228 self.assertEqual(str(self.logged_message).find(t.name), -1)229 self.logged_message = ''230 hook.end(sess)231 if at_end:232 
self.assertRegexpMatches(str(self.logged_message), t.name)233 else:234 # assertNotRegexpMatches is not supported by python 3.1 and later235 self.assertEqual(str(self.logged_message).find(t.name), -1)236 def test_print_every_n_steps(self):237 with ops.Graph().as_default(), session_lib.Session() as sess:238 self._validate_print_every_n_steps(sess, at_end=False)239 # Verify proper reset.240 self._validate_print_every_n_steps(sess, at_end=False)241 def test_print_every_n_steps_and_end(self):242 with ops.Graph().as_default(), session_lib.Session() as sess:243 self._validate_print_every_n_steps(sess, at_end=True)244 # Verify proper reset.245 self._validate_print_every_n_steps(sess, at_end=True)246 def test_print_first_step(self):247 # if it runs every iteration, first iteration has None duration.248 with ops.Graph().as_default(), session_lib.Session() as sess:249 t = constant_op.constant(42.0, name='foo')250 train_op = constant_op.constant(3)251 hook = basic_session_run_hooks.LoggingTensorHook(252 tensors={'foo': t}, every_n_iter=1)253 hook.begin()254 mon_sess = monitored_session._HookedSession(sess, [hook])255 sess.run(variables_lib.global_variables_initializer())256 mon_sess.run(train_op)257 self.assertRegexpMatches(str(self.logged_message), 'foo')258 # in first run, elapsed time is None.259 self.assertEqual(str(self.logged_message).find('sec'), -1)260 def _validate_print_every_n_secs(self, sess, at_end):261 t = constant_op.constant(42.0, name='foo')262 train_op = constant_op.constant(3)263 hook = basic_session_run_hooks.LoggingTensorHook(264 tensors=[t.name], every_n_secs=1.0, at_end=at_end)265 hook.begin()266 mon_sess = monitored_session._HookedSession(sess, [hook])267 sess.run(variables_lib.global_variables_initializer())268 mon_sess.run(train_op)269 self.assertRegexpMatches(str(self.logged_message), t.name)270 # assertNotRegexpMatches is not supported by python 3.1 and later271 self.logged_message = ''272 mon_sess.run(train_op)273 
self.assertEqual(str(self.logged_message).find(t.name), -1)274 time.sleep(1.0)275 self.logged_message = ''276 mon_sess.run(train_op)277 self.assertRegexpMatches(str(self.logged_message), t.name)278 self.logged_message = ''279 hook.end(sess)280 if at_end:281 self.assertRegexpMatches(str(self.logged_message), t.name)282 else:283 # assertNotRegexpMatches is not supported by python 3.1 and later284 self.assertEqual(str(self.logged_message).find(t.name), -1)285 def test_print_every_n_secs(self):286 with ops.Graph().as_default(), session_lib.Session() as sess:287 self._validate_print_every_n_secs(sess, at_end=False)288 # Verify proper reset.289 self._validate_print_every_n_secs(sess, at_end=False)290 def test_print_every_n_secs_and_end(self):291 with ops.Graph().as_default(), session_lib.Session() as sess:292 self._validate_print_every_n_secs(sess, at_end=True)293 # Verify proper reset.294 self._validate_print_every_n_secs(sess, at_end=True)295 def test_print_formatter(self):296 with ops.Graph().as_default(), session_lib.Session() as sess:297 t = constant_op.constant(42.0, name='foo')298 train_op = constant_op.constant(3)299 hook = basic_session_run_hooks.LoggingTensorHook(300 tensors=[t.name], every_n_iter=10,301 formatter=lambda items: 'qqq=%s' % items[t.name])302 hook.begin()303 mon_sess = monitored_session._HookedSession(sess, [hook])304 sess.run(variables_lib.global_variables_initializer())305 mon_sess.run(train_op)306 self.assertEqual(self.logged_message[0], 'qqq=42.0')307class CheckpointSaverHookTest(test.TestCase):308 def setUp(self):309 self.model_dir = tempfile.mkdtemp()310 self.graph = ops.Graph()311 with self.graph.as_default():312 self.scaffold = monitored_session.Scaffold()313 self.global_step = variables.get_or_create_global_step()314 self.train_op = state_ops.assign_add(self.global_step, 1)315 def tearDown(self):316 shutil.rmtree(self.model_dir, ignore_errors=True)317 def test_raise_when_saver_and_scaffold_both_missing(self):318 with 
self.assertRaises(ValueError):319 basic_session_run_hooks.CheckpointSaverHook(self.model_dir)320 def test_raise_when_saver_and_scaffold_both_present(self):321 with self.assertRaises(ValueError):322 basic_session_run_hooks.CheckpointSaverHook(323 self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold)324 def test_raise_in_both_secs_and_steps(self):325 with self.assertRaises(ValueError):326 basic_session_run_hooks.CheckpointSaverHook(327 self.model_dir, save_secs=10, save_steps=20)328 def test_raise_in_none_secs_and_steps(self):329 with self.assertRaises(ValueError):330 basic_session_run_hooks.CheckpointSaverHook(self.model_dir)331 def test_save_secs_saves_in_first_step(self):332 with self.graph.as_default():333 hook = basic_session_run_hooks.CheckpointSaverHook(334 self.model_dir, save_secs=2, scaffold=self.scaffold)335 hook.begin()336 self.scaffold.finalize()337 with session_lib.Session() as sess:338 sess.run(self.scaffold.init_op)339 mon_sess = monitored_session._HookedSession(sess, [hook])340 mon_sess.run(self.train_op)341 self.assertEqual(1,342 checkpoint_utils.load_variable(self.model_dir,343 self.global_step.name))344 def test_save_secs_calls_listeners_at_begin_and_end(self):345 with self.graph.as_default():346 listener = MockCheckpointSaverListener()347 hook = basic_session_run_hooks.CheckpointSaverHook(348 self.model_dir,349 save_secs=2,350 scaffold=self.scaffold,351 listeners=[listener])352 hook.begin()353 self.scaffold.finalize()354 with session_lib.Session() as sess:355 sess.run(self.scaffold.init_op)356 mon_sess = monitored_session._HookedSession(sess, [hook])357 mon_sess.run(self.train_op) # hook runs here358 mon_sess.run(self.train_op) # hook won't run here, so it does at end359 hook.end(sess) # hook runs here360 self.assertEqual({361 'begin': 1,362 'before_save': 2,363 'after_save': 2,364 'end': 1365 }, listener.get_counts())366 def test_listener_with_monitored_session(self):367 with ops.Graph().as_default():368 scaffold = 
monitored_session.Scaffold()369 global_step = variables.get_or_create_global_step()370 train_op = state_ops.assign_add(global_step, 1)371 listener = MockCheckpointSaverListener()372 hook = basic_session_run_hooks.CheckpointSaverHook(373 self.model_dir,374 save_steps=1,375 scaffold=scaffold,376 listeners=[listener])377 with monitored_session.SingularMonitoredSession(378 hooks=[hook],379 scaffold=scaffold,380 checkpoint_dir=self.model_dir) as sess:381 sess.run(train_op)382 sess.run(train_op)383 global_step_val = sess.run(global_step)384 listener_counts = listener.get_counts()385 self.assertEqual(2, global_step_val)386 self.assertEqual({387 'begin': 1,388 'before_save': 2,389 'after_save': 2,390 'end': 1391 }, listener_counts)392 def test_listener_with_default_saver(self):393 with ops.Graph().as_default():394 global_step = variables.get_or_create_global_step()395 train_op = state_ops.assign_add(global_step, 1)396 listener = MockCheckpointSaverListener()397 hook = basic_session_run_hooks.CheckpointSaverHook(398 self.model_dir,399 save_steps=1,400 listeners=[listener])401 with monitored_session.SingularMonitoredSession(402 hooks=[hook],403 checkpoint_dir=self.model_dir) as sess:404 sess.run(train_op)405 sess.run(train_op)406 global_step_val = sess.run(global_step)407 listener_counts = listener.get_counts()408 self.assertEqual(2, global_step_val)409 self.assertEqual({410 'begin': 1,411 'before_save': 2,412 'after_save': 2,413 'end': 1414 }, listener_counts)415 with ops.Graph().as_default():416 global_step = variables.get_or_create_global_step()417 with monitored_session.SingularMonitoredSession(418 checkpoint_dir=self.model_dir) as sess2:419 global_step_saved_val = sess2.run(global_step)420 self.assertEqual(2, global_step_saved_val)421 def test_two_listeners_with_default_saver(self):422 with ops.Graph().as_default():423 global_step = variables.get_or_create_global_step()424 train_op = state_ops.assign_add(global_step, 1)425 listener1 = MockCheckpointSaverListener()426 
listener2 = MockCheckpointSaverListener()427 hook = basic_session_run_hooks.CheckpointSaverHook(428 self.model_dir,429 save_steps=1,430 listeners=[listener1, listener2])431 with monitored_session.SingularMonitoredSession(432 hooks=[hook],433 checkpoint_dir=self.model_dir) as sess:434 sess.run(train_op)435 sess.run(train_op)436 global_step_val = sess.run(global_step)437 listener1_counts = listener1.get_counts()438 listener2_counts = listener2.get_counts()439 self.assertEqual(2, global_step_val)440 self.assertEqual({441 'begin': 1,442 'before_save': 2,443 'after_save': 2,444 'end': 1445 }, listener1_counts)446 self.assertEqual(listener1_counts, listener2_counts)447 with ops.Graph().as_default():448 global_step = variables.get_or_create_global_step()449 with monitored_session.SingularMonitoredSession(450 checkpoint_dir=self.model_dir) as sess2:451 global_step_saved_val = sess2.run(global_step)452 self.assertEqual(2, global_step_saved_val)453 @test.mock.patch('time.time')454 def test_save_secs_saves_periodically(self, mock_time):455 # Let's have a realistic start time456 current_time = 1484695987.209386457 with self.graph.as_default():458 mock_time.return_value = current_time459 hook = basic_session_run_hooks.CheckpointSaverHook(460 self.model_dir, save_secs=2, scaffold=self.scaffold)461 hook.begin()462 self.scaffold.finalize()463 with session_lib.Session() as sess:464 sess.run(self.scaffold.init_op)465 mon_sess = monitored_session._HookedSession(sess, [hook])466 mock_time.return_value = current_time467 mon_sess.run(self.train_op) # Saved.468 mock_time.return_value = current_time + 0.5469 mon_sess.run(self.train_op) # Not saved.470 self.assertEqual(1,471 checkpoint_utils.load_variable(self.model_dir,472 self.global_step.name))473 # Simulate 2.5 seconds of sleep.474 mock_time.return_value = current_time + 2.5475 mon_sess.run(self.train_op) # Saved.476 mock_time.return_value = current_time + 2.6477 mon_sess.run(self.train_op) # Not saved.478 mock_time.return_value = 
current_time + 2.7479 mon_sess.run(self.train_op) # Not saved.480 self.assertEqual(3,481 checkpoint_utils.load_variable(self.model_dir,482 self.global_step.name))483 # Simulate 7.5 more seconds of sleep (10 seconds from start.484 mock_time.return_value = current_time + 10485 mon_sess.run(self.train_op) # Saved.486 self.assertEqual(6,487 checkpoint_utils.load_variable(self.model_dir,488 self.global_step.name))489 # Flaky because of time.sleep()490 def DISABLED_test_save_secs_calls_listeners_periodically(self):491 with self.graph.as_default():492 listener = MockCheckpointSaverListener()493 hook = basic_session_run_hooks.CheckpointSaverHook(494 self.model_dir,495 save_secs=2,496 scaffold=self.scaffold,497 listeners=[listener])498 hook.begin()499 self.scaffold.finalize()500 with session_lib.Session() as sess:501 sess.run(self.scaffold.init_op)502 mon_sess = monitored_session._HookedSession(sess, [hook])503 mon_sess.run(self.train_op) # hook runs here504 mon_sess.run(self.train_op)505 time.sleep(2.5)506 mon_sess.run(self.train_op) # hook runs here507 mon_sess.run(self.train_op)508 mon_sess.run(self.train_op)509 time.sleep(2.5)510 mon_sess.run(self.train_op) # hook runs here511 mon_sess.run(self.train_op) # hook won't run here, so it does at end512 hook.end(sess) # hook runs here513 self.assertEqual({514 'begin': 1,515 'before_save': 4,516 'after_save': 4,517 'end': 1518 }, listener.get_counts())519 def test_save_steps_saves_in_first_step(self):520 with self.graph.as_default():521 hook = basic_session_run_hooks.CheckpointSaverHook(522 self.model_dir, save_steps=2, scaffold=self.scaffold)523 hook.begin()524 self.scaffold.finalize()525 with session_lib.Session() as sess:526 sess.run(self.scaffold.init_op)527 mon_sess = monitored_session._HookedSession(sess, [hook])528 mon_sess.run(self.train_op)529 self.assertEqual(1,530 checkpoint_utils.load_variable(self.model_dir,531 self.global_step.name))532 def test_save_steps_saves_periodically(self):533 with 
self.graph.as_default():534 hook = basic_session_run_hooks.CheckpointSaverHook(535 self.model_dir, save_steps=2, scaffold=self.scaffold)536 hook.begin()537 self.scaffold.finalize()538 with session_lib.Session() as sess:539 sess.run(self.scaffold.init_op)540 mon_sess = monitored_session._HookedSession(sess, [hook])541 mon_sess.run(self.train_op)542 mon_sess.run(self.train_op)543 # Not saved544 self.assertEqual(1,545 checkpoint_utils.load_variable(self.model_dir,546 self.global_step.name))547 mon_sess.run(self.train_op)548 # saved549 self.assertEqual(3,550 checkpoint_utils.load_variable(self.model_dir,551 self.global_step.name))552 mon_sess.run(self.train_op)553 # Not saved554 self.assertEqual(3,555 checkpoint_utils.load_variable(self.model_dir,556 self.global_step.name))557 mon_sess.run(self.train_op)558 # saved559 self.assertEqual(5,560 checkpoint_utils.load_variable(self.model_dir,561 self.global_step.name))562 def test_save_saves_at_end(self):563 with self.graph.as_default():564 hook = basic_session_run_hooks.CheckpointSaverHook(565 self.model_dir, save_secs=2, scaffold=self.scaffold)566 hook.begin()567 self.scaffold.finalize()568 with session_lib.Session() as sess:569 sess.run(self.scaffold.init_op)570 mon_sess = monitored_session._HookedSession(sess, [hook])571 mon_sess.run(self.train_op)572 mon_sess.run(self.train_op)573 hook.end(sess)574 self.assertEqual(2,575 checkpoint_utils.load_variable(self.model_dir,576 self.global_step.name))577 def test_summary_writer_defs(self):578 fake_summary_writer.FakeSummaryWriter.install()579 writer_cache.FileWriterCache.clear()580 summary_writer = writer_cache.FileWriterCache.get(self.model_dir)581 with self.graph.as_default():582 hook = basic_session_run_hooks.CheckpointSaverHook(583 self.model_dir, save_steps=2, scaffold=self.scaffold)584 hook.begin()585 self.scaffold.finalize()586 with session_lib.Session() as sess:587 sess.run(self.scaffold.init_op)588 mon_sess = monitored_session._HookedSession(sess, [hook])589 
mon_sess.run(self.train_op)590 summary_writer.assert_summaries(591 test_case=self,592 expected_logdir=self.model_dir,593 expected_added_meta_graphs=[594 meta_graph.create_meta_graph_def(595 graph_def=self.graph.as_graph_def(add_shapes=True),596 saver_def=self.scaffold.saver.saver_def)597 ])598 fake_summary_writer.FakeSummaryWriter.uninstall()599class ResourceCheckpointSaverHookTest(test.TestCase):600 def setUp(self):601 self.model_dir = tempfile.mkdtemp()602 self.graph = ops.Graph()603 with self.graph.as_default():604 self.scaffold = monitored_session.Scaffold()605 with variable_scope.variable_scope('foo', use_resource=True):606 self.global_step = variables.get_or_create_global_step()607 self.train_op = state_ops.assign_add(self.global_step, 1)608 def test_save_steps_saves_periodically(self):609 with self.graph.as_default():610 hook = basic_session_run_hooks.CheckpointSaverHook(611 self.model_dir, save_steps=2, scaffold=self.scaffold)612 hook.begin()613 self.scaffold.finalize()614 with session_lib.Session() as sess:615 sess.run(self.scaffold.init_op)616 mon_sess = monitored_session._HookedSession(sess, [hook])617 mon_sess.run(self.train_op)618 mon_sess.run(self.train_op)619 # Not saved620 self.assertEqual(1,621 checkpoint_utils.load_variable(self.model_dir,622 self.global_step.name))623 mon_sess.run(self.train_op)624 # saved625 self.assertEqual(3,626 checkpoint_utils.load_variable(self.model_dir,627 self.global_step.name))628 mon_sess.run(self.train_op)629 # Not saved630 self.assertEqual(3,631 checkpoint_utils.load_variable(self.model_dir,632 self.global_step.name))633 mon_sess.run(self.train_op)634 # saved635 self.assertEqual(5,636 checkpoint_utils.load_variable(self.model_dir,637 self.global_step.name))638class StepCounterHookTest(test.TestCase):639 def setUp(self):640 self.log_dir = tempfile.mkdtemp()641 def tearDown(self):642 shutil.rmtree(self.log_dir, ignore_errors=True)643 def test_step_counter_every_n_steps(self):644 with ops.Graph().as_default() as g, 
session_lib.Session() as sess:645 global_step = variables.get_or_create_global_step()646 train_op = state_ops.assign_add(global_step, 1)647 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)648 hook = basic_session_run_hooks.StepCounterHook(649 summary_writer=summary_writer, every_n_steps=10)650 hook.begin()651 sess.run(variables_lib.global_variables_initializer())652 mon_sess = monitored_session._HookedSession(sess, [hook])653 for _ in range(30):654 time.sleep(0.01)655 mon_sess.run(train_op)656 hook.end(sess)657 summary_writer.assert_summaries(658 test_case=self,659 expected_logdir=self.log_dir,660 expected_graph=g,661 expected_summaries={})662 self.assertItemsEqual([11, 21], summary_writer.summaries.keys())663 for step in [11, 21]:664 summary_value = summary_writer.summaries[step][0].value[0]665 self.assertEqual('global_step/sec', summary_value.tag)666 self.assertGreater(summary_value.simple_value, 0)667 def test_step_counter_every_n_secs(self):668 with ops.Graph().as_default() as g, session_lib.Session() as sess:669 global_step = variables.get_or_create_global_step()670 train_op = state_ops.assign_add(global_step, 1)671 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)672 hook = basic_session_run_hooks.StepCounterHook(673 summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)674 hook.begin()675 sess.run(variables_lib.global_variables_initializer())676 mon_sess = monitored_session._HookedSession(sess, [hook])677 mon_sess.run(train_op)678 time.sleep(0.2)679 mon_sess.run(train_op)680 time.sleep(0.2)681 mon_sess.run(train_op)682 hook.end(sess)683 summary_writer.assert_summaries(684 test_case=self,685 expected_logdir=self.log_dir,686 expected_graph=g,687 expected_summaries={})688 self.assertTrue(summary_writer.summaries, 'No summaries were created.')689 self.assertItemsEqual([2, 3], summary_writer.summaries.keys())690 for summary in summary_writer.summaries.values():691 summary_value = summary[0].value[0]692 
self.assertEqual('global_step/sec', summary_value.tag)693 self.assertGreater(summary_value.simple_value, 0)694 def test_global_step_name(self):695 with ops.Graph().as_default() as g, session_lib.Session() as sess:696 with variable_scope.variable_scope('bar'):697 foo_step = variable_scope.get_variable(698 'foo',699 initializer=0,700 trainable=False,701 collections=[702 ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES703 ])704 train_op = state_ops.assign_add(foo_step, 1)705 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)706 hook = basic_session_run_hooks.StepCounterHook(707 summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)708 hook.begin()709 sess.run(variables_lib.global_variables_initializer())710 mon_sess = monitored_session._HookedSession(sess, [hook])711 mon_sess.run(train_op)712 mon_sess.run(train_op)713 hook.end(sess)714 summary_writer.assert_summaries(715 test_case=self,716 expected_logdir=self.log_dir,717 expected_graph=g,718 expected_summaries={})719 self.assertTrue(summary_writer.summaries, 'No summaries were created.')720 self.assertItemsEqual([2], summary_writer.summaries.keys())721 summary_value = summary_writer.summaries[2][0].value[0]722 self.assertEqual('bar/foo/sec', summary_value.tag)723class SummarySaverHookTest(test.TestCase):724 def setUp(self):725 test.TestCase.setUp(self)726 self.log_dir = 'log/dir'727 self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)728 var = variables_lib.Variable(0.0)729 tensor = state_ops.assign_add(var, 1.0)730 tensor2 = tensor * 2731 self.summary_op = summary_lib.scalar('my_summary', tensor)732 self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)733 global_step = variables.get_or_create_global_step()734 self.train_op = state_ops.assign_add(global_step, 1)735 def test_raise_when_scaffold_and_summary_op_both_missing(self):736 with self.assertRaises(ValueError):737 basic_session_run_hooks.SummarySaverHook()738 def 
test_raise_when_scaffold_and_summary_op_both_present(self):
    # Passing both `scaffold` and `summary_op` is expected to be rejected.
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          scaffold=monitored_session.Scaffold(), summary_op=self.summary_op)

  def test_raise_in_both_secs_and_steps(self):
    # `save_secs` and `save_steps` given together must raise.
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          save_secs=10, save_steps=20, summary_writer=self.summary_writer)

  def test_raise_in_none_secs_and_steps(self):
    # Neither `save_secs` nor `save_steps` given must raise as well.
    with self.assertRaises(ValueError):
      basic_session_run_hooks.SummarySaverHook(
          save_secs=None, save_steps=None, summary_writer=self.summary_writer)

  def test_save_steps(self):
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(30):
        mon_sess.run(self.train_op)
      hook.end(sess)

    # With save_steps=8 over 30 runs, summaries are expected at global steps
    # 1, 9, 17 and 25.
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0
            },
            9: {
                'my_summary': 2.0
            },
            17: {
                'my_summary': 3.0
            },
            25: {
                'my_summary': 4.0
            },
        })

  def test_multiple_summaries(self):
    # `summary_op` is a list here; both summaries are expected at each save.
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=[self.summary_op, self.summary_op2])

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(10):
        mon_sess.run(self.train_op)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0,
                'my_summary2': 2.0
            },
            9: {
                'my_summary': 2.0,
                'my_summary2': 4.0
            },
        })

  def test_save_secs_saving_once_every_step(self):
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.5,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(4):
        mon_sess.run(self.train_op)
        # Sleep for the full save_secs interval so that every step saves.
        time.sleep(0.5)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0
            },
            2: {
                'my_summary': 2.0
            },
            3: {
                'my_summary': 3.0
            },
            4: {
                'my_summary': 4.0
            },
        })

  def test_save_secs_saving_once_every_three_steps(self):
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.9,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(8):
        mon_sess.run(self.train_op)
        # Three 0.3s sleeps add up to the 0.9s save interval, hence the
        # expected saves at steps 1, 4 and 7 below.
        time.sleep(0.3)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0
            },
            4: {
                'my_summary': 2.0
            },
            7: {
                'my_summary': 3.0
            },
        })


class GlobalStepWaiterHookTest(test.TestCase):
  """Tests for GlobalStepWaiterHook.before_run blocking behaviour."""

  def test_not_wait_for_step_zero(self):
    with ops.Graph().as_default():
      variables.get_or_create_global_step()
      hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
      hook.begin()
      with session_lib.Session() as sess:
        # Before run should return without waiting gstep increment.
        hook.before_run(
            session_run_hook.SessionRunContext(
                original_args=None, session=sess))

  def test_wait_for_step(self):
    with ops.Graph().as_default():
      gstep = variables.get_or_create_global_step()
      hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
      hook.begin()
      with session_lib.Session() as sess:
        sess.run(variables_lib.global_variables_initializer())
        # Run before_run on a daemon thread so the test can observe whether
        # it is still blocked.
        waiter = threading.Thread(
            target=hook.before_run,
            args=(session_run_hook.SessionRunContext(
                original_args=None, session=sess),))
        waiter.daemon = True
        waiter.start()
        time.sleep(1.0)
        self.assertTrue(waiter.is_alive())
        # 500 is still below wait_until_step=1000: the waiter must keep
        # blocking.
        sess.run(state_ops.assign(gstep, 500))
        time.sleep(1.0)
        self.assertTrue(waiter.is_alive())
        # 1100 exceeds wait_until_step: the waiter must finish.
        sess.run(state_ops.assign(gstep, 1100))
        time.sleep(1.2)
        self.assertFalse(waiter.is_alive())


class FinalOpsHookTest(test.TestCase):
  """Tests that FinalOpsHook exposes the evaluated ops as final_ops_values."""

  def test_final_ops_is_scalar_tensor(self):
    with ops.Graph().as_default():
      expected_value = 4
      final_ops = constant_op.constant(expected_value)

      hook = basic_session_run_hooks.FinalOpsHook(final_ops)
      hook.begin()

      with session_lib.Session() as session:
        hook.end(session)
        self.assertEqual(expected_value,
                         hook.final_ops_values)

  def test_final_ops_is_tensor(self):
    with ops.Graph().as_default():
      expected_values = [1, 6, 3, 5, 2, 4]
      final_ops = constant_op.constant(expected_values)

      hook = basic_session_run_hooks.FinalOpsHook(final_ops)
      hook.begin()

      with session_lib.Session() as session:
        hook.end(session)
        self.assertListEqual(expected_values,
                             hook.final_ops_values.tolist())

  def test_final_ops_with_dictionary(self):
    with ops.Graph().as_default():
      expected_values = [4, -3]
      # The final op is a placeholder; its value comes from the feed dict
      # handed to the hook.
      final_ops = array_ops.placeholder(dtype=dtypes.float32)

      final_ops_feed_dict = {final_ops: expected_values}

      hook = basic_session_run_hooks.FinalOpsHook(
          final_ops, final_ops_feed_dict)
      hook.begin()

      with session_lib.Session() as session:
        hook.end(session)
        self.assertListEqual(expected_values,
                             hook.final_ops_values.tolist())


class ResourceSummarySaverHookTest(test.TestCase):
  """SummarySaverHook test using variables created with use_resource=True."""

  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    var = variable_scope.get_variable('var', initializer=0.0, use_resource=True)
    tensor = state_ops.assign_add(var, 1.0)
    self.summary_op = summary_lib.scalar('my_summary', tensor)

    with variable_scope.variable_scope('foo', use_resource=True):
      global_step = variables.get_or_create_global_step()
    self.train_op = state_ops.assign_add(global_step, 1)

  def test_save_steps(self):
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(30):
        mon_sess.run(self.train_op)
      hook.end(sess)

    # With save_steps=8 over 30 runs, summaries are expected at global steps
    # 1, 9, 17 and 25.
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0
            },
            9: {
                'my_summary': 2.0
            },
            17: {
                'my_summary': 3.0
            },
            25: {
                'my_summary': 4.0
            },
        })


class FeedFnHookTest(test.TestCase):
  """Tests that FeedFnHook supplies feeds produced by `feed_fn`."""

  def test_feeding_placeholder(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      x = array_ops.placeholder(dtype=dtypes.float32)
      y = x + 1
      hook = basic_session_run_hooks.FeedFnHook(
          feed_fn=lambda: {x: 1.0})
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      # x is fed 1.0 by the hook, so y = x + 1 evaluates to 2.
      self.assertEqual(mon_sess.run(y), 2)


if __name__ == '__main__':...
local_cli_wrapper_test.py
Source:local_cli_wrapper_test.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for local command-line-interface debug wrapper session."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import tempfile

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest


class LocalCLIDebuggerWrapperSessionForTest(
    local_cli_wrapper.LocalCLIDebugWrapperSession):
  """Subclasses the wrapper class for testing.

  Overrides its CLI-related methods for headless testing environments.
  Inserts observer variables for assertions.
  """

  def __init__(self,
               command_args_sequence,
               sess,
               dump_root=None):
    """Constructor of the for-test subclass.

    Args:
      command_args_sequence: (list of list of str) A list of arguments for the
        "run" command.
      sess: See the doc string of LocalCLIDebugWrapperSession.__init__.
      dump_root: See the doc string of LocalCLIDebugWrapperSession.__init__.
    """

    local_cli_wrapper.LocalCLIDebugWrapperSession.__init__(
        self, sess, dump_root=dump_root, log_usage=False)

    self._command_args_sequence = command_args_sequence
    # Index of the next entry of `command_args_sequence` to be consumed by
    # `_launch_cli`.
    self._response_pointer = 0

    # Observer variables.
    self.observers = {
        "debug_dumps": [],
        "tf_errors": [],
        "run_start_cli_run_numbers": [],
        "run_end_cli_run_numbers": [],
        "profiler_py_graphs": [],
        "profiler_run_metadata": [],
    }

  def _prep_cli_for_run_start(self):
    # No CLI to prepare in the headless test subclass.
    pass

  def _prep_debug_cli_for_run_end(self, debug_dump, tf_error, passed_filter):
    # Record what the run-end CLI would have been given; `passed_filter` is
    # not recorded.
    self.observers["debug_dumps"].append(debug_dump)
    self.observers["tf_errors"].append(tf_error)

  def _prep_profile_cli_for_run_end(self, py_graph, run_metadata):
    # Record the profiler inputs for later assertions.
    self.observers["profiler_py_graphs"].append(py_graph)
    self.observers["profiler_run_metadata"].append(run_metadata)

  def _launch_cli(self):
    """Simulates a CLI launch by replaying the next scripted "run" command.

    Records the current run call count under the run-start or run-end
    observer, then passes the next scripted argument list to `_run_handler`.

    Returns:
      The exit token of the `CommandLineExit` raised by `_run_handler`.
    """
    if self._is_run_start:
      self.observers["run_start_cli_run_numbers"].append(self._run_call_count)
    else:
      self.observers["run_end_cli_run_numbers"].append(self._run_call_count)

    command_args = self._command_args_sequence[self._response_pointer]
    self._response_pointer += 1

    # NOTE: `response` is only bound when `_run_handler` raises
    # CommandLineExit; a normal return would make the `return` below fail.
    try:
      self._run_handler(command_args)
    except debugger_cli_common.CommandLineExit as e:
      response = e.exit_token

    return response


class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
  """Tests for LocalCLIDebugWrapperSession using the headless subclass above."""

  def setUp(self):
    # mktemp() only returns a path; several tests below create the directory
    # themselves with os.mkdir().
    self._tmp_dir = tempfile.mktemp()

    self.v = variables.Variable(10.0, name="v")
    self.w = variables.Variable(21.0, name="w")
    self.delta = constant_op.constant(1.0, name="delta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")

    # Evaluating w_int also runs inc_v through the control dependency.
    self.w_int = control_flow_ops.with_dependencies(
        [self.inc_v],
        math_ops.cast(self.w, dtypes.int32, name="w_int_inner"),
        name="w_int_outer")

    self.ph = array_ops.placeholder(dtypes.float32, name="ph")
    self.xph = array_ops.transpose(self.ph, name="xph")
    self.m = constant_op.constant(
        [[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=dtypes.float32, name="m")
    self.y = math_ops.matmul(self.m, self.xph, name="y")

    self.sess = session.Session()

    # Initialize variable.
    self.sess.run(variables.global_variables_initializer())

  def tearDown(self):
    ops.reset_default_graph()
    # Remove the dump root if a test created it.
    if os.path.isdir(self._tmp_dir):
      shutil.rmtree(self._tmp_dir)

  def testConstructWrapper(self):
    local_cli_wrapper.LocalCLIDebugWrapperSession(
        session.Session(), log_usage=False)

  def testConstructWrapperWithExistingEmptyDumpRoot(self):
    os.mkdir(self._tmp_dir)
    self.assertTrue(os.path.isdir(self._tmp_dir))

    local_cli_wrapper.LocalCLIDebugWrapperSession(
        session.Session(), dump_root=self._tmp_dir, log_usage=False)

  def testConstructWrapperWithExistingNonEmptyDumpRoot(self):
    os.mkdir(self._tmp_dir)
    dir_path = os.path.join(self._tmp_dir, "foo")
    os.mkdir(dir_path)
    self.assertTrue(os.path.isdir(dir_path))

    with self.assertRaisesRegexp(
        ValueError, "dump_root path points to a non-empty directory"):
      local_cli_wrapper.LocalCLIDebugWrapperSession(
          session.Session(), dump_root=self._tmp_dir, log_usage=False)

  def testConstructWrapperWithExistingFileDumpRoot(self):
    os.mkdir(self._tmp_dir)
    file_path = os.path.join(self._tmp_dir, "foo")
    open(file_path, "a").close()  # Create the file
    self.assertTrue(os.path.isfile(file_path))
    with self.assertRaisesRegexp(ValueError, "dump_root path points to a file"):
      local_cli_wrapper.LocalCLIDebugWrapperSession(
          session.Session(), dump_root=file_path, log_usage=False)

  def testRunsUnderDebugMode(self):
    # Test command sequence: run; run; run;
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], [], []], self.sess, dump_root=self._tmp_dir)

    # run under debug mode twice.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    # Verify that the assign_add op did take effect.
    self.assertAllClose(12.0, self.sess.run(self.v))

    # Assert correct run call numbers for which the CLI has been launched at
    # run-start and run-end.
    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
    self.assertEqual([1, 2], wrapped_sess.observers["run_end_cli_run_numbers"])

    # Verify that the dumps have been generated and picked up during run-end.
    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))

    # Verify that the TensorFlow runtime errors are picked up and in this case,
    # they should be both None.
    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])

  def testRunsWithEmptyStringDumpRootWorks(self):
    # Test command sequence: run, run
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], []], self.sess, dump_root="")

    # run under debug mode.
    wrapped_sess.run(self.inc_v)

    self.assertAllClose(11.0, self.sess.run(self.v))

  def testRunInfoOutputAtRunEndIsCorrect(self):
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], [], []], self.sess, dump_root=self._tmp_dir)

    wrapped_sess.run(self.inc_v)
    run_info_output = wrapped_sess._run_info_handler([])

    tfdbg_logo = cli_shared.get_tfdbg_logo()

    # The run_info output in the first run() call should contain the tfdbg logo.
    self.assertEqual(tfdbg_logo.lines,
                     run_info_output.lines[:len(tfdbg_logo.lines)])
    menu = run_info_output.annotations[debugger_cli_common.MAIN_MENU_KEY]
    self.assertIn("list_tensors", menu.captions())

    wrapped_sess.run(self.inc_v)
    run_info_output = wrapped_sess._run_info_handler([])

    # The run_info output in the second run() call should NOT contain the logo.
    self.assertNotEqual(tfdbg_logo.lines,
                        run_info_output.lines[:len(tfdbg_logo.lines)])
    menu = run_info_output.annotations[debugger_cli_common.MAIN_MENU_KEY]
    self.assertIn("list_tensors", menu.captions())

  def testRunsUnderNonDebugMode(self):
    # Test command sequence: run -n; run -n; run -n;
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["-n"], ["-n"], ["-n"]], self.sess, dump_root=self._tmp_dir)

    # run three times.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    self.assertAllClose(13.0, self.sess.run(self.v))

    self.assertEqual([1, 2, 3],
                     wrapped_sess.observers["run_start_cli_run_numbers"])
    self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])

  def testRunsUnderNonDebugThenDebugMode(self):
    # Test command sequence: run -n; run -n; run; run;
    # Do two NON_DEBUG_RUNs, followed by DEBUG_RUNs.
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["-n"], ["-n"], [], []], self.sess, dump_root=self._tmp_dir)

    # run three times.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    self.assertAllClose(13.0, self.sess.run(self.v))

    self.assertEqual([1, 2, 3],
                     wrapped_sess.observers["run_start_cli_run_numbers"])

    # Here, the CLI should have been launched only under the third run,
    # because the first and second runs are NON_DEBUG.
    self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
    self.assertEqual([None], wrapped_sess.observers["tf_errors"])

  def testRunMultipleTimesWithinLimit(self):
    # Test command sequence: run -t 3; run;
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["-t", "3"], []], self.sess, dump_root=self._tmp_dir)

    # run three times.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    self.assertAllClose(13.0, self.sess.run(self.v))

    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
    self.assertEqual([3], wrapped_sess.observers["run_end_cli_run_numbers"])
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
    self.assertEqual([None], wrapped_sess.observers["tf_errors"])

  def testRunMultipleTimesOverLimit(self):
    # Test command sequence: run -t 3;
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["-t", "3"]], self.sess, dump_root=self._tmp_dir)

    # run twice, which is less than the number of times specified by the
    # command.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    self.assertAllClose(12.0, self.sess.run(self.v))

    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
    self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
    self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))
    self.assertEqual([], wrapped_sess.observers["tf_errors"])

  def testRunMixingDebugModeAndMultpleTimes(self):
    # Test command sequence: run -n; run -t 2; run; run;
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["-n"], ["-t", "2"], [], []], self.sess, dump_root=self._tmp_dir)

    # run four times.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    self.assertAllClose(14.0, self.sess.run(self.v))

    self.assertEqual([1, 2],
                     wrapped_sess.observers["run_start_cli_run_numbers"])
    self.assertEqual([3, 4], wrapped_sess.observers["run_end_cli_run_numbers"])
    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])

  def testDebuggingMakeCallableTensorRunnerWorks(self):
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], []], self.sess, dump_root=self._tmp_dir)
    v = variables.Variable(42)
    tensor_runner = wrapped_sess.make_callable(v)
    # Initialize through the unwrapped session so that only the callable's
    # run is counted in the debug dumps.
    self.sess.run(v.initializer)

    self.assertAllClose(42, tensor_runner())
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))

  def testDebuggingMakeCallableTensorRunnerWithCustomRunOptionsWorks(self):
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], []], self.sess, dump_root=self._tmp_dir)
    a = constant_op.constant(42)
    tensor_runner = wrapped_sess.make_callable(a)

    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    run_metadata = config_pb2.RunMetadata()
    self.assertAllClose(
        42, tensor_runner(options=run_options, run_metadata=run_metadata))
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
    # The caller-supplied run_metadata must have been filled in.
    self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)

  def testDebuggingMakeCallableOperationRunnerWorks(self):
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], []], self.sess, dump_root=self._tmp_dir)
    v = variables.Variable(10.0)
    inc_v = state_ops.assign_add(v, 1.0)
    op_runner = wrapped_sess.make_callable(inc_v.op)
    self.sess.run(v.initializer)

    op_runner()
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
    self.assertEqual(11.0, self.sess.run(v))

  def testDebuggingMakeCallableRunnerWithFeedListWorks(self):
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], []], self.sess, dump_root=self._tmp_dir)
    ph1 = array_ops.placeholder(dtypes.float32)
    ph2 = array_ops.placeholder(dtypes.float32)
    a = math_ops.add(ph1, ph2)
    tensor_runner = wrapped_sess.make_callable(a, feed_list=[ph1, ph2])

    self.assertAllClose(42.0, tensor_runner(41.0, 1.0))
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))

  def testRuntimeErrorShouldBeCaught(self):
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[], []], self.sess, dump_root=self._tmp_dir)

    # Do a run that should lead to an TensorFlow runtime error.
    # (`m` is 2x3 and the fed 3x1 value is transposed to 1x3 before matmul.)
    wrapped_sess.run(self.y, feed_dict={self.ph: [[0.0], [1.0], [2.0]]})

    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])
    self.assertEqual([1], wrapped_sess.observers["run_end_cli_run_numbers"])
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))

    # Verify that the runtime error is caught by the wrapped session properly.
    self.assertEqual(1, len(wrapped_sess.observers["tf_errors"]))
    tf_error = wrapped_sess.observers["tf_errors"][0]
    self.assertEqual("y", tf_error.op.name)

  def testRuntimeErrorBeforeGraphExecutionIsRaised(self):
    # Use an impossible device name to cause an error before graph execution.
    with ops.device("/gpu:1337"):
      w = variables.Variable([1.0] * 10, name="w")

    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [[]], self.sess, dump_root=self._tmp_dir)
    with self.assertRaisesRegexp(errors.OpError, r".*[Dd]evice.*1337.*"):
      wrapped_sess.run(w)

  def testRunTillFilterPassesShouldLaunchCLIAtCorrectRun(self):
    # Test command sequence:
    #   run -f v_greater_than_twelve; run -f v_greater_than_twelve; run;
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["-f", "v_greater_than_twelve"], ["-f", "v_greater_than_twelve"], []],
        self.sess,
        dump_root=self._tmp_dir)

    def v_greater_than_twelve(datum, tensor):
      return datum.node_name == "v" and tensor > 12.0

    wrapped_sess.add_tensor_filter("v_greater_than_twelve",
                                   v_greater_than_twelve)

    # run five times.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    self.assertAllClose(15.0, self.sess.run(self.v))

    self.assertEqual([1], wrapped_sess.observers["run_start_cli_run_numbers"])

    # run-end CLI should NOT have been launched for run #2 and #3, because only
    # starting from run #4 v becomes greater than 12.0.
    self.assertEqual([4, 5], wrapped_sess.observers["run_end_cli_run_numbers"])

    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))
    self.assertEqual([None, None], wrapped_sess.observers["tf_errors"])

  def testRunsUnderDebugModeWithWatchFnFilteringNodeNames(self):
    # Test command sequence:
    #   run --node_name_filter inc.*
    #   run --node_name_filter delta
    #   run
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["--node_name_filter", "inc.*"], ["--node_name_filter", "delta"], []],
        self.sess, dump_root=self._tmp_dir)

    # run under debug mode twice.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    # Verify that the assign_add op did take effect.
    self.assertAllClose(12.0, self.sess.run(self.v))

    # Verify that the dumps have been generated and picked up during run-end.
    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))

    dumps = wrapped_sess.observers["debug_dumps"][0]
    self.assertEqual(1, dumps.size)
    self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name)

    dumps = wrapped_sess.observers["debug_dumps"][1]
    self.assertEqual(1, dumps.size)
    self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name)

  def testRunsUnderDebugModeWithWatchFnFilteringOpTypes(self):
    # Test command sequence:
    #   run --node_name_filter delta
    #   run --op_type_filter AssignAdd
    #   run
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["--node_name_filter", "delta"],
         ["--op_type_filter", "AssignAdd"],
         []],
        self.sess, dump_root=self._tmp_dir)

    # run under debug mode twice.
    wrapped_sess.run(self.inc_v)
    wrapped_sess.run(self.inc_v)

    # Verify that the assign_add op did take effect.
    self.assertAllClose(12.0, self.sess.run(self.v))

    # Verify that the dumps have been generated and picked up during run-end.
    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))

    dumps = wrapped_sess.observers["debug_dumps"][0]
    self.assertEqual(1, dumps.size)
    self.assertEqual("delta", dumps.dumped_tensor_data[0].node_name)

    dumps = wrapped_sess.observers["debug_dumps"][1]
    self.assertEqual(1, dumps.size)
    self.assertEqual("inc_v", dumps.dumped_tensor_data[0].node_name)

  def testRunsUnderDebugModeWithWatchFnFilteringTensorDTypes(self):
    # Test command sequence:
    #   run --op_type_filter Variable.*
    #   run --tensor_dtype_filter int32
    #   run
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["--op_type_filter", "Variable.*"],
         ["--tensor_dtype_filter", "int32"], []],
        self.sess, dump_root=self._tmp_dir)

    # run under debug mode twice.
    wrapped_sess.run(self.w_int)
    wrapped_sess.run(self.w_int)

    # Verify that the dumps have been generated and picked up during run-end.
    self.assertEqual(2, len(wrapped_sess.observers["debug_dumps"]))

    dumps = wrapped_sess.observers["debug_dumps"][0]
    self.assertEqual(2, dumps.size)
    self.assertItemsEqual(
        ["v", "w"], [dumps.dumped_tensor_data[i].node_name for i in [0, 1]])

    dumps = wrapped_sess.observers["debug_dumps"][1]
    self.assertEqual(2, dumps.size)
    self.assertEqual(
        ["w_int_inner", "w_int_outer"],
        [dumps.dumped_tensor_data[i].node_name for i in [0, 1]])

  def testRunsUnderDebugModeWithWatchFnFilteringOpTypesAndTensorDTypes(self):
    # Test command sequence:
    #   run --op_type_filter Cast --tensor_dtype_filter int32
    #   run
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["--op_type_filter", "Cast", "--tensor_dtype_filter", "int32"], []],
        self.sess, dump_root=self._tmp_dir)

    # run under debug mode once.
    wrapped_sess.run(self.w_int)

    # Verify that the dumps have been generated and picked up during run-end.
    self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))

    dumps = wrapped_sess.observers["debug_dumps"][0]
    self.assertEqual(1, dumps.size)
    self.assertEqual("w_int_inner", dumps.dumped_tensor_data[0].node_name)

  def testRunUnderProfilerModeWorks(self):
    # Test command sequence: run -p; run
    wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
        [["-p"], []], self.sess)

    wrapped_sess.run(self.w_int)

    self.assertEqual(1, len(wrapped_sess.observers["profiler_run_metadata"]))
    self.assertTrue(
        wrapped_sess.observers["profiler_run_metadata"][0].step_stats)
    self.assertEqual(1, len(wrapped_sess.observers["profiler_py_graphs"]))
    self.assertIsInstance(
        wrapped_sess.observers["profiler_py_graphs"][0], ops.Graph)


if __name__ == "__main__":...
backend.py
Source:backend.py
1import asyncio, time2from collections import defaultdict3from enum import IntFlag4from util.misc import gen_uuid, EMPTY_SET, run_loop5from .user import UserService6from .auth import AuthService7from .stats import Stats8from .models import User, Group, Lst, Contact, UserStatus9from . import error, event10class Ack(IntFlag):11 Zero = 012 NAK = 113 ACK = 214 Full = 315class Backend:16 def __init__(self, loop, *, user_service = None, auth_service = None):17 self._loop = loop18 self._user_service = user_service or UserService()19 self._auth_service = auth_service or AuthService()20 self._stats = Stats()21 22 self._sc = _SessionCollection()23 # Dict[User.uuid, User]24 self._user_by_uuid = {}25 # Dict[User, UserDetail]26 self._unsynced_db = {}27 28 # Dict[chatid, Chat]29 self._chats = {}30 31 self._runners = []32 33 loop.create_task(self._sync_db())34 loop.create_task(self._clean_sessions())35 loop.create_task(self._sync_stats())36 37 def add_runner(self, runner):38 self._runners.append(runner)39 40 def run_forever(self):41 run_loop(self._loop, self._runners)42 43 def on_leave(self, sess):44 user = sess.user45 if user is None: return46 self._stats.on_logout()47 self._sc.remove_session(sess)48 if self._sc.get_sessions_by_user(user):49 # There are still other people logged in as this user,50 # so don't send offline notifications.51 return52 # User is offline, send notifications53 user.detail = None54 self._sync_contact_statuses()55 self._generic_notify(sess)56 57 def login_md5_get_salt(self, email):58 return self._user_service.get_md5_salt(email)59 60 def login_md5_verify(self, sess, email, md5_hash):61 uuid = self._user_service.login_md5(email, md5_hash)62 return self._login_common(sess, uuid, email)63 64 def login_twn_start(self, email, password):65 uuid = self._user_service.login(email, password)66 if uuid is None: return None67 return self._auth_service.create_token('nb/login', uuid)68 69 def login_twn_verify(self, sess, email, token):70 uuid = 
self._auth_service.pop_token('nb/login', token)71 return self._login_common(sess, uuid, email)72 73 def login_IKWIAD(self, sess, email):74 uuid = self.util_get_uuid_from_email(email)75 return self._login_common(sess, uuid, email)76 77 def _login_common(self, sess, uuid, email):78 if uuid is None: return None79 self._user_service.update_date_login(uuid)80 user = self._load_user_record(uuid)81 sess.user = user82 self._stats.on_login()83 self._stats.on_user_active(user, sess.client)84 self._sc.add_session(sess)85 user.detail = self._load_detail(user)86 return user87 88 def _load_user_record(self, uuid):89 if uuid not in self._user_by_uuid:90 user = self._user_service.get(uuid)91 if user is None: return None92 self._user_by_uuid[uuid] = user93 return self._user_by_uuid[uuid]94 95 def _load_detail(self, user):96 if user.detail: return user.detail97 return self._user_service.get_detail(user.uuid)98 99 def _generic_notify(self, sess):100 # Notify relevant `Session`s of status, name, message, media101 user = sess.user102 if user is None: return103 # TODO: This does a lot of work, iterating through _every_ session.104 # If RL is set up properly, could iterate through `user.detail.contacts`.105 for sess_other in self._sc.iter_sessions():106 if sess_other == sess: continue107 user_other = sess_other.user108 if user_other is None: continue109 if user_other.detail is None: continue110 ctc = user_other.detail.contacts.get(user.uuid)111 if ctc is None: continue112 sess_other.send_event(event.PresenceNotificationEvent(ctc))113 114 def _sync_contact_statuses(self):115 # Recompute all `Contact.status`'s116 for user in self._user_by_uuid.values():117 detail = user.detail118 if detail is None: continue119 for ctc in detail.contacts.values():120 ctc.compute_visible_status(user)121 122 def _mark_modified(self, user, *, detail = None):123 ud = user.detail or detail124 if detail: assert ud is detail125 assert ud is not None126 self._unsynced_db[user] = ud127 128 def sb_token_create(self, 
sess, *, extra_data = None):129 if extra_data is None:130 extra_data = {}131 extra_data['client'] = sess.client132 return self._auth_service.create_token('sb/xfr', { 'uuid': sess.user.uuid, 'extra_data': extra_data })133 134 def me_update(self, sess, fields):135 user = sess.user136 137 if 'message' in fields:138 user.status.message = fields['message']139 if 'media' in fields:140 user.status.media = fields['media']141 if 'name' in fields:142 user.status.name = fields['name']143 if 'gtc' in fields:144 user.detail.settings['gtc'] = fields['gtc']145 if 'blp' in fields:146 user.detail.settings['blp'] = fields['blp']147 if 'substatus' in fields:148 user.status.substatus = fields['substatus']149 150 self._mark_modified(user)151 self._sync_contact_statuses()152 self._generic_notify(sess)153 154 def me_group_add(self, sess, name, *, is_favorite = None):155 if len(name) > MAX_GROUP_NAME_LENGTH:156 raise error.GroupNameTooLong()157 user = sess.user158 group = Group(_gen_group_id(user.detail), name, is_favorite = is_favorite)159 user.detail.groups[group.id] = group160 self._mark_modified(user)161 return group162 163 def me_group_remove(self, sess, group_id):164 if group_id == '0':165 raise error.CannotRemoveSpecialGroup()166 user = sess.user167 try:168 del user.detail.groups[group_id]169 except KeyError:170 raise error.GroupDoesNotExist()171 for ctc in user.detail.contacts.values():172 ctc.groups.discard(group_id)173 self._mark_modified(user)174 175 def me_group_edit(self, sess, group_id, new_name, *, is_favorite = None):176 user = sess.user177 g = user.detail.groups.get(group_id)178 if g is None:179 raise error.GroupDoesNotExist()180 if new_name is not None:181 if len(new_name) > MAX_GROUP_NAME_LENGTH:182 raise error.GroupNameTooLong()183 g.new_name = new_name184 if is_favorite is not None:185 g.is_favorite = is_favorite186 self._mark_modified(user)187 188 def me_group_contact_add(self, sess, group_id, contact_uuid):189 if group_id == '0': return190 user = sess.user191 detail 
= user.detail192 if group_id not in detail.groups:193 raise error.GroupDoesNotExist()194 ctc = detail.contacts.get(contact_uuid)195 if ctc is None:196 raise error.ContactDoesNotExist()197 if group_id in ctc.groups:198 raise error.ContactAlreadyOnList()199 ctc.groups.add(group_id)200 self._mark_modified(user)201 202 def me_group_contact_remove(self, sess, group_id, contact_uuid):203 user = sess.user204 detail = user.detail205 ctc = detail.contacts.get(contact_uuid)206 if ctc is None:207 raise error.ContactDoesNotExist()208 if group_id not in detail.groups and group_id != '0':209 raise error.GroupDoesNotExist()210 try:211 ctc.groups.remove(group_id)212 except KeyError:213 if group_id == '0':214 raise error.ContactNotOnList()215 self._mark_modified(user)216 217 def me_contact_add(self, sess, contact_uuid, lst, name):218 ctc_head = self._load_user_record(contact_uuid)219 if ctc_head is None:220 raise error.UserDoesNotExist()221 user = sess.user222 ctc = self._add_to_list(user, ctc_head, lst, name)223 if lst is Lst.FL:224 # FL needs a matching RL on the contact225 self._add_to_list(ctc_head, user, Lst.RL, user.status.name)226 self._notify_reverse_add(sess, ctc_head)227 self._sync_contact_statuses()228 self._generic_notify(sess)229 return ctc, ctc_head230 231 def _notify_reverse_add(self, sess, user_added):232 user_adder = sess.user233 # `user_added` was added to `user_adder`'s RL234 for sess_added in self._sc.get_sessions_by_user(user_added):235 if sess_added == sess: continue236 sess_added.send_event(event.AddedToListEvent(Lst.RL, user_adder))237 238 def me_contact_edit(self, sess, contact_uuid, *, is_messenger_user = None):239 user = sess.user240 ctc = user.detail.contacts.get(contact_uuid)241 if ctc is None:242 raise error.ContactDoesNotExist()243 if is_messenger_user is not None:244 ctc.is_messenger_user = is_messenger_user245 self._mark_modified(user)246 247 def me_contact_remove(self, sess, contact_uuid, lst):248 user = sess.user249 ctc = 
user.detail.contacts.get(contact_uuid)250 if ctc is None:251 raise error.ContactDoesNotExist()252 if lst is Lst.FL:253 # Remove from FL254 self._remove_from_list(user, ctc.head, Lst.FL)255 # Remove matching RL256 self._remove_from_list(ctc.head, user, Lst.RL)257 else:258 assert lst is not Lst.RL259 ctc.lists &= ~lst260 self._mark_modified(user)261 self._sync_contact_statuses()262 263 def _add_to_list(self, user, ctc_head, lst, name):264 # Add `ctc_head` to `user`'s `lst`265 detail = self._load_detail(user)266 contacts = detail.contacts267 if ctc_head.uuid not in contacts:268 contacts[ctc_head.uuid] = Contact(ctc_head, set(), 0, UserStatus(name))269 ctc = contacts.get(ctc_head.uuid)270 if ctc.status.name is None:271 ctc.status.name = name272 ctc.lists |= lst273 self._mark_modified(user, detail = detail)274 return ctc275 276 def _remove_from_list(self, user, ctc_head, lst):277 # Remove `ctc_head` from `user`'s `lst`278 detail = self._load_detail(user)279 contacts = detail.contacts280 ctc = contacts.get(ctc_head.uuid)281 if ctc is None: return282 ctc.lists &= ~lst283 if not ctc.lists:284 del contacts[ctc_head.uuid]285 self._mark_modified(user, detail = detail)286 287 def me_pop_boot_others(self, sess):288 for sess_other in self._sc.get_sessions_by_user(sess.user):289 if sess is sess_other: continue290 sess_other.send_event(event.POPBootEvent())291 292 def me_pop_notify_others(self, sess):293 for sess_other in self._sc.get_sessions_by_user(sess.user):294 if sess is sess_other: continue295 sess_other.send_event(event.POPNotifyEvent())296 297 def login_xfr(self, sess, email, token):298 (user, extra_data) = self._load_user('sb/xfr', token)299 if user is None: return None300 if user.email != email: return None301 sess.user = user302 sess.client = extra_data['client']303 chat = Chat(self._stats)304 self._chats[chat.id] = chat305 chat.add_session(sess)306 return chat, extra_data307 308 def login_cal(self, sess, email, token, chatid):309 (user, extra_data) = 
self._load_user('sb/cal', token)310 if user is None: return None311 if user.email != email: return None312 sess.user = user313 sess.client = extra_data['client']314 chat = self._chats.get(chatid)315 if chat is None: return None316 chat.add_session(sess)317 return chat, extra_data318 319 def _load_user(self, purpose, token):320 data = self._auth_service.pop_token(purpose, token)321 if data is None: return (None, None)322 return (self._user_service.get(data['uuid']), data['extra_data'])323 324 def util_get_uuid_from_email(self, email):325 return self._user_service.get_uuid(email)326 327 def util_set_sess_token(self, sess, token):328 self._sc.set_nc_by_token(sess, token)329 330 def util_get_sess_by_token(self, token):331 return self._sc.get_nc_by_token(token)332 333 def util_get_sessions_by_user(self, user):334 return self._sc.get_sessions_by_user(user)335 336 def notify_call(self, caller_uuid, callee_email, chatid):337 caller = self._user_by_uuid.get(caller_uuid)338 if caller is None: raise error.ServerError()339 if caller.detail is None: raise error.ServerError()340 callee_uuid = self._user_service.get_uuid(callee_email)341 if callee_uuid is None: raise error.UserDoesNotExist()342 ctc = caller.detail.contacts.get(callee_uuid)343 if ctc is None:344 if callee_uuid != caller_uuid: raise error.ContactDoesNotExist()345 ctc_user = caller346 else:347 if ctc.status.is_offlineish(): raise error.ContactNotOnline()348 ctc_user = ctc.head349 ctc_sessions = self._sc.get_sessions_by_user(ctc_user)350 if not ctc_sessions: raise error.ContactNotOnline()351 352 for ctc_sess in ctc_sessions:353 extra_data = ctc_sess.state.get_sb_extra_data() or {}354 extra_data['client'] = ctc_sess.client355 token = self._auth_service.create_token('sb/cal', { 'uuid': ctc_user.uuid, 'extra_data': extra_data })356 ctc_sess.send_event(event.InvitedToChatEvent(chatid, token, caller))357 358 async def _sync_db(self):359 while True:360 await asyncio.sleep(1)361 self._sync_db_impl()362 363 def 
_sync_db_impl(self):364 if not self._unsynced_db: return365 try:366 users = list(self._unsynced_db.keys())[:100]367 batch = []368 for user in users:369 detail = self._unsynced_db.pop(user, None)370 if not detail: continue371 batch.append((user, detail))372 self._user_service.save_batch(batch)373 except Exception:374 import traceback375 traceback.print_exc()376 377 async def _clean_sessions(self):378 from .session import PollingSession379 while True:380 await asyncio.sleep(10)381 now = time.time()382 closed = []383 384 try:385 for sess in self._sc.iter_sessions():386 if sess.closed:387 closed.append(sess)388 continue389 if isinstance(sess, PollingSession):390 if now >= sess.time_last_connect + sess.timeout:391 sess.close()392 closed.append(sess)393 except Exception:394 import traceback395 traceback.print_exc()396 397 for sess in closed:398 self._sc.remove_session(sess)399 400 async def _sync_stats(self):401 while True:402 await asyncio.sleep(60)403 try:404 self._stats.flush()405 except Exception:406 import traceback407 traceback.print_exc()408class _SessionCollection:409 def __init__(self):410 # Set[Session]411 self._sessions = set()412 # Dict[User, Set[Session]]413 self._sessions_by_user = defaultdict(set)414 # Dict[str, Session]415 self._sess_by_token = {}416 # Dict[Session, Set[str]]417 self._tokens_by_sess = defaultdict(set)418 419 def get_sessions_by_user(self, user):420 if user not in self._sessions_by_user:421 return EMPTY_SET422 return self._sessions_by_user[user]423 424 def iter_sessions(self):425 yield from self._sessions426 427 def set_nc_by_token(self, sess, token: str):428 self._sess_by_token[token] = sess429 self._tokens_by_sess[sess].add(sess)430 self._sessions.add(sess)431 432 def get_nc_by_token(self, token: str):433 return self._sess_by_token.get(token)434 435 def add_session(self, sess):436 if sess.user:437 self._sessions_by_user[sess.user].add(sess)438 self._sessions.add(sess)439 440 def remove_session(self, sess):441 if sess in 
self._tokens_by_sess:442 tokens = self._tokens_by_sess.pop(sess)443 for token in tokens:444 self._sess_by_token.pop(token, None)445 self._sessions.discard(sess)446 if sess.user in self._sessions_by_user:447 self._sessions_by_user[sess.user].discard(sess)448class Chat:449 def __init__(self, stats):450 self.id = gen_uuid()451 # Dict[Session, User]452 self._users_by_sess = {}453 self._stats = stats454 455 def add_session(self, sess):456 self._users_by_sess[sess] = sess.user457 458 def send_message_to_everyone(self, sess_sender, data):459 self._stats.on_message_sent(sess_sender.user, sess_sender.client)460 self._stats.on_user_active(sess_sender.user, sess_sender.client)461 su_sender = self._users_by_sess[sess_sender]462 for sess in self._users_by_sess.keys():463 if sess == sess_sender: continue464 sess.send_event(event.ChatMessage(su_sender, data))465 self._stats.on_message_received(sess.user, sess.client)466 467 def get_roster(self, sess):468 roster = []469 for sess1, su1 in self._users_by_sess.items():470 if sess1 == sess: continue471 roster.append((sess1, su1))472 return roster473 474 def send_participant_joined(self, sess):475 for sc, _ in self.get_roster(self):476 sc.send_event(event.ChatParticipantJoined(sess))477 478 def on_leave(self, sess):479 su = self._users_by_sess.pop(sess, None)480 if su is None: return481 # Notify others that `sess` has left482 for sess1, su1 in self._users_by_sess.items():483 if sess1 == sess: continue484 sess1.send_event(event.ChatParticipantLeft(su))485def _gen_group_id(detail):486 id = 1487 s = str(id)488 while s in detail.groups:489 id += 1490 s = str(id)491 return s...
gen_dsin_input.py
Source:gen_dsin_input.py
"""Build the DSIN model input (session features, labels, feature spec)."""
import os

import numpy as np
import pandas as pd
from deepctr.utils import SingleFeat
from sklearn.preprocessing import LabelEncoder, StandardScaler
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
from tqdm import tqdm

from config import DSIN_SESS_COUNT, DSIN_SESS_MAX_LEN, FRAC

SESS_COUNT = DSIN_SESS_COUNT


def gen_sess_feature_dsin(row):
    """Build the per-session behaviour features for one sample.

    Reads the module-level ``user_hist_session`` mapping (user -> list of
    sessions), which is only populated by the ``__main__`` section below.
    Each session is indexed as a sequence of events ``e`` with ``e[0]`` used
    as cate_id, ``e[1]`` as brand and ``e[2]`` compared against the sample's
    time stamp.

    Args:
        row: an ``(index, series)`` pair as produced by
            ``DataFrame.iterrows()``; the series must provide ``'user'`` and
            ``'time_stamp'``.

    Returns:
        ``(sess_input_dict, sess_input_length_dict, sess_length)`` where the
        dicts are keyed ``'sess_0'`` .. ``'sess_<DSIN_SESS_COUNT - 1>'`` and
        ``sess_length`` is the number of sessions that were filled.
    """
    sess_count = DSIN_SESS_COUNT
    sess_max_len = DSIN_SESS_MAX_LEN
    sess_names = ['sess_' + str(i) for i in range(sess_count)]
    sess_input_dict = {name: {'cate_id': [], 'brand': []}
                       for name in sess_names}
    sess_input_length_dict = {name: 0 for name in sess_names}

    user, time_stamp = row[1]['user'], row[1]['time_stamp']

    if user not in user_hist_session:
        # Unknown user: every session slot holds a single padding id.
        for name in sess_names:
            sess_input_dict[name]['cate_id'] = [0]
            sess_input_dict[name]['brand'] = [0]
        return sess_input_dict, sess_input_length_dict, 0

    sessions = user_hist_session[user]
    valid_sess_count = 0
    last_sess_idx = len(sessions) - 1

    # Walk the sessions from last to first and pick, as 'sess_0', the first
    # one that has more than two events before the sample's time stamp.
    for i in reversed(range(len(sessions))):
        cur_sess = sessions[i]
        if cur_sess[0][2] < time_stamp:
            in_sess_count = 1 + sum(
                1 for e in cur_sess[1:] if e[2] < time_stamp)
            if in_sess_count > 2:
                # Keep at most `sess_max_len` of the counted events.
                window = cur_sess[max(0, in_sess_count - sess_max_len):
                                  in_sess_count]
                sess_input_dict['sess_0']['cate_id'] = [e[0] for e in window]
                sess_input_dict['sess_0']['brand'] = [e[1] for e in window]
                sess_input_length_dict['sess_0'] = min(
                    sess_max_len, in_sess_count)
                last_sess_idx = i
                valid_sess_count += 1
                break

    # Fill the remaining slots with the sessions preceding the chosen one.
    for i in range(1, sess_count):
        name = 'sess_' + str(i)
        if last_sess_idx - i >= 0:
            cur_sess = sessions[last_sess_idx - i]
            tail = cur_sess[-sess_max_len:]
            sess_input_dict[name]['cate_id'] = [e[0] for e in tail]
            sess_input_dict[name]['brand'] = [e[1] for e in tail]
            sess_input_length_dict[name] = min(sess_max_len, len(cur_sess))
            valid_sess_count += 1
        else:
            sess_input_dict[name]['cate_id'] = [0]
            sess_input_dict[name]['brand'] = [0]
            sess_input_length_dict[name] = 0

    return sess_input_dict, sess_input_length_dict, valid_sess_count


if __name__ == "__main__":
    # Load every shard of the per-user session history.
    user_hist_session = {}
    shard_prefix = 'user_hist_session_' + str(FRAC) + '_dsin_'
    FILE_NUM = sum(
        1 for x in os.listdir('../sampled_data/') if x.startswith(shard_prefix))
    print('total', FILE_NUM, 'files')
    for i in range(FILE_NUM):
        # NOTE(review): pickles are trusted local artefacts here; never load
        # pickles from an untrusted source.
        user_hist_session_ = pd.read_pickle(
            '../sampled_data/' + shard_prefix + str(i) + '.pkl')
        user_hist_session.update(user_hist_session_)
        del user_hist_session_

    sample_sub = pd.read_pickle(
        '../sampled_data/raw_sample_' + str(FRAC) + '.pkl')

    sess_input_dict = {}
    sess_input_length_dict = {}
    for i in range(SESS_COUNT):
        sess_input_dict['sess_' + str(i)] = {'cate_id': [], 'brand': []}
        sess_input_length_dict['sess_' + str(i)] = []
    sess_length_list = []

    for row in tqdm(sample_sub[['user', 'time_stamp']].iterrows()):
        sess_input_dict_, sess_input_length_dict_, sess_length = \
            gen_sess_feature_dsin(row)
        for i in range(SESS_COUNT):
            sess_name = 'sess_' + str(i)
            sess_input_dict[sess_name]['cate_id'].append(
                sess_input_dict_[sess_name]['cate_id'])
            sess_input_dict[sess_name]['brand'].append(
                sess_input_dict_[sess_name]['brand'])
            sess_input_length_dict[sess_name].append(
                sess_input_length_dict_[sess_name])
        sess_length_list.append(sess_length)
    print('done')

    user = pd.read_pickle('../sampled_data/user_profile_' + str(FRAC) + '.pkl')
    ad = pd.read_pickle('../sampled_data/ad_feature_enc_' + str(FRAC) + '.pkl')
    user = user.fillna(-1)
    # The raw column name carries a trailing space.
    user.rename(
        columns={'new_user_class_level ': 'new_user_class_level'}, inplace=True)

    # `sample_sub` has not been modified since it was loaded above, so it is
    # reused instead of being read from disk a second time.
    sample_sub.rename(columns={'user': 'userid'}, inplace=True)

    data = pd.merge(sample_sub, user, how='left', on='userid', )
    data = pd.merge(data, ad, how='left', on='adgroup_id')

    sparse_features = ['userid', 'adgroup_id', 'pid', 'cms_segid', 'cms_group_id', 'final_gender_code', 'age_level',
                       'pvalue_level', 'shopping_level', 'occupation', 'new_user_class_level', 'campaign_id',
                       'customer']  # sparse feature for user and ads
    dense_features = ['price']  # dense feature for user and ads

    for feat in tqdm(sparse_features):
        lbe = LabelEncoder()  # or Hash
        # Map each distinct value to its integer id.
        data[feat] = lbe.fit_transform(data[feat])
    mms = StandardScaler()
    data[dense_features] = mms.fit_transform(data[dense_features])

    # SingleFeat fields: name, dimension, hash_flag, dtype
    sparse_feature_list = [SingleFeat(feat, data[feat].nunique() + 1)
                           for feat in sparse_features + ['cate_id', 'brand']]
    dense_feature_list = [SingleFeat(feat, 1) for feat in dense_features]

    sess_feature = ['cate_id', 'brand']  # sess feature for ad
    sess_input = []
    for i in tqdm(range(SESS_COUNT)):
        sess_name = 'sess_' + str(i)
        for feat in sess_feature:
            # Pad to the per-session length limit used when the sessions were
            # built. Padding to SESS_COUNT (the number of sessions) would
            # truncate any session longer than that.
            sess_input.append(pad_sequences(
                sess_input_dict[sess_name][feat], maxlen=DSIN_SESS_MAX_LEN, padding='post'))

    model_input = [data[feat.name].values for feat in sparse_feature_list] + \
                  [data[feat.name].values for feat in dense_feature_list]
    sess_lists = sess_input + [np.array(sess_length_list)]
    model_input += sess_lists

    os.makedirs('../model_input/', exist_ok=True)

    pd.to_pickle(model_input, '../model_input/dsin_input_' +
                 str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
    pd.to_pickle(data['clk'].values, '../model_input/dsin_label_' +
                 str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
    pd.to_pickle({'sparse': sparse_feature_list, 'dense': dense_feature_list},
                 '../model_input/dsin_fd_' + str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
Looking for an in-depth pytest tutorial? LambdaTest's detailed pytest tutorial covers everything related to pytest, from setting up the pytest framework to automation testing. Delve deeper into pytest by exploring advanced use cases like parallel testing, pytest fixtures, parameterization, executing multiple test cases from a single file, and more.
Skim our pytest tutorial playlist below to get started with automation testing using the pytest framework.
https://www.youtube.com/playlist?list=PLZMWkkQEwOPlcGgDmHl8KkXKeLF83XlrP
Get 100 minutes of automation testing FREE!!