Best Python code snippet using molotov_python
Source: controller_test.py
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for orbit.controller."""

import os

from absl import logging
from absl.testing import parameterized
import numpy as np

from orbit import controller
from orbit import runner
from orbit import standard_runner

import tensorflow as tf


def create_model():
  x = tf.keras.layers.Input(shape=(3,), name="input")
  y = tf.keras.layers.Dense(4, name="dense")(x)
  model = tf.keras.Model(x, y)
  return model


def summaries_with_matching_keyword(keyword, summary_dir):
  """Returns summary protos matching given keyword from event file."""
  matches = []
  event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
  for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
    if event.summary is not None:
      for value in event.summary.value:
        if keyword in value.tag:
          matches.append(event.summary)
  return matches


def dataset_fn(ctx):
  del ctx
  inputs = np.zeros((10, 3), dtype=np.float32)
  targets = np.ones((10, 4), dtype=np.float32)
  dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset


class TestRunner(standard_runner.StandardTrainer,
                 standard_runner.StandardEvaluator):
  """Implements the training and evaluation APIs for the test model."""

  def __init__(self, return_numpy=False):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    self.eval_loss = tf.keras.metrics.Mean("eval_loss", dtype=tf.float32)
    self.return_numpy = return_numpy
    train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardTrainer.__init__(self, train_dataset)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def train_loop_end(self):
    train_loss = self.train_loss.result()
    return {
        "loss": train_loss.numpy() if self.return_numpy else train_loss,
    }

  def build_eval_dataset(self):
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def eval_begin(self):
    self.eval_loss.reset_states()

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      self.eval_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def eval_end(self):
    eval_loss = self.eval_loss.result()
    return {
        "eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
    }


class TestEvaluator(standard_runner.StandardEvaluator):
  """Implements the training and evaluation APIs for the test model."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def eval_reduce(self, state, output):
    state.append(output)
    return state

  def eval_begin(self):
    return []

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
      return loss

    per_replica_losses = self.strategy.run(
        _replicated_step, args=(next(iterator),))
    mean_loss = self.strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return mean_loss

  def eval_end(self, outputs):
    return {
        "eval_loss": tf.reduce_mean(outputs),
    }


class TestEvaluatorNoOutput(runner.AbstractEvaluator):

  def evaluate(self, num_steps):
    pass


class TestEvaluatorWithNestedSummary(standard_runner.StandardEvaluator):
  """Implements the training and evaluation APIs for the test model."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    dataset2 = self.strategy.distribute_datasets_from_function(dataset_fn)
    self.loss = tf.keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy = tf.keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    self.loss2 = tf.keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy2 = tf.keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    standard_runner.StandardEvaluator.__init__(
        self, eval_dataset={
            "dataset": dataset,
            "dataset2": dataset2
        })

  def eval_step(self, iterator):

    def _replicated_step(loss, accuracy, inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss.update_state(tf.keras.losses.MSE(targets, outputs))
      accuracy.update_state(targets, outputs)

    self.strategy.run(
        lambda inputs: _replicated_step(self.loss, self.accuracy, inputs),
        args=(next(iterator["dataset"]),))
    self.strategy.run(
        lambda inputs: _replicated_step(self.loss2, self.accuracy2, inputs),
        args=(next(iterator["dataset2"]),))

  def eval_end(self):
    return {
        "dataset": {
            "loss": self.loss.result(),
            "accuracy": self.accuracy.result()
        },
        "dataset2": {
            "loss": self.loss2.result(),
            "accuracy": self.accuracy2.result()
        },
    }


class TestTrainerWithSummaries(standard_runner.StandardTrainer):
  """A Trainer model with summaries for testing purposes."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf.keras.metrics.Mean("train_loss", dtype=tf.float32)
    train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardTrainer.__init__(
        self,
        train_dataset,
        options=standard_runner.StandardTrainerOptions(
            use_tpu_summary_optimization=True))

  def build_train_dataset(self):
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf.keras.losses.MSE(targets, outputs))
        tf.summary.scalar("loss", loss)
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))


class ControllerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.model_dir = self.get_temp_dir()

  def test_no_checkpoint(self):
    test_runner = TestRunner()
    # No checkpoint manager and no strategy.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Loss and accuracy values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    # No checkpoint, so global step starts from 0.
    test_runner.global_step.assign(0)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

  def test_no_checkpoint_and_summaries(self):
    test_runner = TestRunner()
    # No checkpoint + summary directories.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)

  def test_has_checkpoint_no_summaries(self):
    test_runner = TestRunner()
    # Has checkpoint, but no summary directories.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # No summaries are saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))

  def test_has_checkpoint_eval_summary_only(self):
    test_runner = TestRunner()
    # Has checkpoint and eval summary dir, but no train summary dir.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Training summaries are not saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))
    # Evaluation summaries are saved.
    self.assertNotEmpty(tf.io.gfile.glob(
        os.path.join(self.model_dir, "summaries/eval/events.*")))

  def test_restore_from_most_recent_checkpoint(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=5)
    test_controller.train(20)
    self.assertLen(checkpoint_manager.checkpoints, 4)
    restored_path = test_controller.restore_checkpoint()
    self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])

  @parameterized.named_parameters(("return_numpy", True),
                                  ("return_tensor", False))
  def test_train_and_evaluate(self, return_numpy):
    test_runner = TestRunner(return_numpy=return_numpy)
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Loss and accuracy values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_train_only(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    test_controller.train(steps=10)
    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_only(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    eval_results = test_controller.evaluate(steps=2)
    # Only eval summaries are written.
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    self.assertIn("eval_loss", eval_results)
    # Tests continuous eval with timeout and timeout_fn.
    done_file = os.path.join(self.model_dir, "summaries/eval/Done")

    def timeout_fn():
      with tf.io.gfile.GFile(done_file, "w") as f:
        f.write("DONE")
      return True

    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate_continuously(
        timeout=1, timeout_fn=timeout_fn, steps=2)
    self.assertNotEmpty(tf.io.gfile.glob(done_file))

  def test_no_eval_steps(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager)
    test_controller.evaluate()

  def test_already_trained_model(self):
    test_runner = TestRunner()
    test_runner.global_step.assign(10)
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    # `global_step` is already `train_steps`.
    test_controller.train(steps=10)

  def test_summaries_inside_train_fn(self):
    test_runner = TestTrainerWithSummaries()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager,
    )
    test_controller.train(steps=10)
    # Checkpoints are not saved.
    self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_train_and_evaluate_with_same_summary_dir(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    # Loss and accuracy values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries")))

  def test_early_stop_on_eval_loss(self):
    test_runner = TestRunner()

    class EarlyStopController(controller.Controller):
      """A subclass of Controller that supports early stopping."""

      def train_and_evaluate(self,
                             train_steps: int = None,
                             eval_steps: int = None,
                             eval_interval: int = None):
        while self.global_step.numpy() < train_steps:
          interval = min(train_steps - self.global_step.numpy(), eval_interval)
          num_steps = self.global_step.numpy() + interval
          self.train(steps=num_steps, checkpoint_at_completion=False)
          self.evaluate(steps=eval_steps)
          # Early stop condition.
          if test_runner.eval_loss.result() < 0.1:
            logging.info(
                "Training early stopped as eval_loss %s is less than 0.1",
                test_runner.eval_loss.result())
            return

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = EarlyStopController(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=6, eval_interval=2)
    self.assertLess(test_runner.global_step, 10)

  def test_evaluate_with_loss_output(self):
    test_evaluator = TestEvaluator()
    checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, self.model_dir, max_to_keep=None)
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate(steps=5)
    # Only eval summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_with_no_output(self):
    test_controller = controller.Controller(
        evaluator=TestEvaluatorNoOutput(),
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    self.assertEqual(test_controller.evaluate(steps=5), {})

  def test_train_and_evaluate_reset_datasets(self):
    test_runner = TestRunner()
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    train_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    eval_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    test_runner.train_dataset = train_dataset
    test_runner.eval_dataset = eval_dataset
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

  def test_eval_and_checkpoint_interval(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=10,
        checkpoint_manager=checkpoint_manager,
        summary_dir=self.model_dir)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=5)
    # Expect 2 checkpoints to be saved, at steps 5 and 10.
    self.assertLen(
        tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 2)
    # Expect evaluation is performed 2 times at step: 5, 10.
    self.assertLen(
        summaries_with_matching_keyword("eval_loss", self.model_dir), 2)

  def test_evaluate_with_nested_summaries(self):
    test_evaluator = TestEvaluatorWithNestedSummary()
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=self.model_dir)
    test_controller.evaluate(steps=5)
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset2")))

  def test_actions(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)

    class OutputRecorderAction:
      """Simple `Action` that just saves the outputs passed to `__call__`."""

      def __init__(self):
        self.outputs = []

      def __call__(self, output):
        self.outputs.append(output)

    train_output_recorder = OutputRecorderAction()
    eval_output_recorder = OutputRecorderAction()
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        train_actions=[train_output_recorder],
        eval_actions=[eval_output_recorder],
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertLen(train_output_recorder.outputs, 5)
    for output in train_output_recorder.outputs:
      self.assertIn("loss", output)
      self.assertGreaterEqual(output["loss"], 0)
    self.assertLen(eval_output_recorder.outputs, 2)
    for output in eval_output_recorder.outputs:
      self.assertIn("eval_loss", output)
      self.assertGreaterEqual(output["eval_loss"], 0)


if __name__ == "__main__":
  tf.test.main()
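The listing above exercises orbit.controller.Controller end to end: a runner that implements the StandardTrainer and StandardEvaluator APIs, a tf.train.CheckpointManager, and a Controller that interleaves training, evaluation, checkpointing, and summary writing. As a quick orientation, the same pieces compose outside a test roughly as follows. This is a minimal sketch that reuses the TestRunner class defined above; the output directory and the step counts are illustrative placeholders, not values from the original file.

import os
import tensorflow as tf
from orbit import controller

# Reuses the TestRunner defined above, which implements both the
# StandardTrainer and StandardEvaluator APIs.
runner_instance = TestRunner()

model_dir = "/tmp/orbit_demo"  # placeholder output directory
checkpoint = tf.train.Checkpoint(
    model=runner_instance.model, optimizer=runner_instance.optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    model_dir,
    max_to_keep=3,
    step_counter=runner_instance.global_step,
    checkpoint_interval=100)

loop = controller.Controller(
    trainer=runner_instance,
    evaluator=runner_instance,
    global_step=runner_instance.global_step,
    steps_per_loop=10,
    checkpoint_manager=checkpoint_manager,
    summary_dir=os.path.join(model_dir, "summaries/train"),
    eval_summary_dir=os.path.join(model_dir, "summaries/eval"))

# Interleave training and evaluation, writing checkpoints and summaries along
# the way; these are the behaviors the tests above assert on.
loop.train_and_evaluate(train_steps=1000, eval_steps=10, eval_interval=100)

The constructor arguments shown here (trainer, evaluator, global_step, steps_per_loop, checkpoint_manager, summary_dir, eval_summary_dir) are exactly the ones the tests above pass; only the concrete values differ.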