saver_test.py
Source: saver_test.py
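The listing below is from TensorFlow's saver_test.py, which exercises the tf.train.Saver checkpointing API (basic save/restore, sharded checkpoints, max_to_keep, keep_checkpoint_every_n_hours, checkpoint state files, and MetaGraph export). For orientation before the test code, here is a minimal sketch of the save/restore cycle these tests cover, written against the public tf.* aliases of the same API; the variable names and the checkpoint path are illustrative only, not taken from the test file.

# Minimal sketch of the tf.train.Saver workflow exercised by the tests below
# (TensorFlow 1.x graph mode; names and the checkpoint path are illustrative).
import tensorflow as tf

v0 = tf.Variable(10.0, name="v0")
v1 = tf.Variable(20.0, name="v1")
saver = tf.train.Saver({"v0": v0, "v1": v1})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # save() returns the checkpoint path prefix that restore() expects.
  ckpt_prefix = saver.save(sess, "/tmp/basic_save_restore")

# Restore into a fresh graph: rebuild matching variables, then restore
# instead of running their initializers.
tf.reset_default_graph()
v0 = tf.Variable(-1.0, name="v0")
v1 = tf.Variable(-1.0, name="v1")
saver = tf.train.Saver({"v0": v0, "v1": v1})

with tf.Session() as sess:
  saver.restore(sess, ckpt_prefix)
  print(sess.run([v0, v1]))  # [10.0, 20.0]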
...90 "v0": v0,91 "v1": v1,92 "v2": v2.saveable93 }, restore_sequentially=True)94 val = save.save(sess, save_path)95 self.assertTrue(isinstance(val, six.string_types))96 self.assertEqual(save_path, val)97 # Start a second session. In that session the parameter nodes98 # have not been initialized either.99 with self.test_session(graph=ops_lib.Graph()) as sess:100 v0 = variable_op(-1.0, name="v0")101 v1 = variable_op(-1.0, name="v1")102 v2 = saver_test_utils.CheckpointedOp(name="v2")103 # Assert that the variables are not initialized.104 if context.in_graph_mode():105 self.assertEqual(106 len(variables.report_uninitialized_variables().eval()), 2)107 self.assertEqual(0, len(v2.keys().eval()))108 self.assertEqual(0, len(v2.values().eval()))109 # Restore the saved values in the parameter nodes.110 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})111 save.restore(sess, save_path)112 # Check that the parameter nodes have been restored.113 self.assertEqual(10.0, self.evaluate(v0))114 self.assertEqual(20.0, self.evaluate(v1))115 self.assertEqual(b"k1", self.evaluate(v2.keys()))116 self.assertEqual(30.0, self.evaluate(v2.values()))117 # Build another graph with 2 nodes, initialized118 # differently, and a Restore node for them.119 with self.test_session(graph=ops_lib.Graph()) as sess:120 v0_2 = variable_op(1000.0, name="v0")121 v1_2 = variable_op(2000.0, name="v1")122 v2_2 = saver_test_utils.CheckpointedOp(name="v2")123 v2_init = v2_2.insert("k1000", 3000.0)124 # Check that the parameter nodes have been initialized.125 if context.in_graph_mode():126 init_all_op = [variables.global_variables_initializer(), v2_init]127 self.evaluate(init_all_op)128 # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty129 # table as it claims in eager mode?130 self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))131 self.assertEqual(3000.0, self.evaluate(v2_2.values()))132 self.assertEqual(1000.0, self.evaluate(v0_2))133 self.assertEqual(2000.0, self.evaluate(v1_2))134 # Restore the values saved earlier in the parameter nodes.135 save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})136 save2.restore(sess, save_path)137 # Check that the parameter nodes have been restored.138 self.assertEqual(10.0, self.evaluate(v0_2))139 self.assertEqual(20.0, self.evaluate(v1_2))140 self.assertEqual(b"k1", self.evaluate(v2_2.keys()))141 self.assertEqual(30.0, self.evaluate(v2_2.values()))142 def testBasic(self):143 self.basicSaveRestore(variables.Variable)144 @test_util.run_in_graph_and_eager_modes()145 def testResourceBasic(self):146 self.basicSaveRestore(resource_variable_ops.ResourceVariable)147 def testEagerBasic(self):148 with context.eager_mode():149 ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")150 v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")151 v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")152 save = saver_module.Saver([v1, v2])153 save.save(None, ckpt_prefix)154 v1.assign(0.0)155 v2.assign([0, 0])156 self.assertNear(0.0, self.evaluate(v1), 1e-5)157 self.assertAllEqual([0, 0], self.evaluate(v2))158 save.restore(None, ckpt_prefix)159 self.assertNear(3.14, self.evaluate(v1), 1e-5)160 self.assertAllEqual([1, 2], self.evaluate(v2))161 def testEagerGraphCompatibility(self):162 # Save from graph mode and restore from eager mode.163 graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")164 with context.graph_mode():165 with self.test_session(graph=ops_lib.Graph()) as sess:166 # Create a graph model and save the checkpoint.167 w1 = 
resource_variable_ops.ResourceVariable(1.0, name="w1")168 w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")169 graph_saver = saver_module.Saver([w1, w2])170 sess.run(variables.global_variables_initializer())171 graph_saver.save(sess, graph_ckpt_prefix)172 with context.eager_mode():173 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access174 ops_lib.reset_default_graph()175 w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")176 w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")177 graph_saver = saver_module.Saver([w1, w2])178 graph_saver.restore(None, graph_ckpt_prefix)179 self.assertAllEqual(self.evaluate(w1), 1.0)180 self.assertAllEqual(self.evaluate(w2), 2.0)181 # Save from eager mode and restore from graph mode.182 eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")183 with context.eager_mode():184 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access185 ops_lib.reset_default_graph()186 w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")187 w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")188 graph_saver = saver_module.Saver([w3, w4])189 graph_saver.save(None, eager_ckpt_prefix)190 with context.graph_mode():191 with self.test_session(graph=ops_lib.Graph()) as sess:192 w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")193 w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")194 graph_saver = saver_module.Saver([w3, w4])195 sess.run(variables.global_variables_initializer())196 graph_saver.restore(sess, eager_ckpt_prefix)197 self.assertAllEqual(w3.eval(), 3.0)198 self.assertAllEqual(w4.eval(), 4.0)199 @test_util.run_in_graph_and_eager_modes()200 def testResourceSaveRestoreCachingDevice(self):201 save_path = os.path.join(self.get_temp_dir(), "resource_cache")202 with self.test_session(graph=ops_lib.Graph()) as sess:203 v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",204 name="v")205 if context.in_graph_mode():206 self.evaluate(variables.global_variables_initializer())207 else:208 sess = None209 save = saver_module.Saver([v])210 save.save(sess, save_path)211 save2 = saver_module.Saver([v])212 save2.restore(sess, save_path)213 self.assertEquals(self.evaluate(v), [1])214 def testSaveCopyRestoreWithSaveRelativePaths(self):215 """Save, copy checkpoint dir and restore from copied dir.216 This only works for save_relative_paths=True.217 """218 save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")219 os.mkdir(save_dir1)220 save_path1 = os.path.join(save_dir1, "save_copy_restore")221 # Build a graph with 2 parameter nodes, and Save and222 # Restore nodes for them.223 v0 = variables.Variable(10.0, name="v0")224 v1 = variables.Variable(20.0, name="v1")225 v2 = saver_test_utils.CheckpointedOp(name="v2")226 v2_init = v2.insert("k1", 30.0)227 save = saver_module.Saver(228 var_list={229 "v0": v0,230 "v1": v1,231 "v2": v2.saveable},232 restore_sequentially=True,233 save_relative_paths=True)234 init_all_op = [variables.global_variables_initializer(), v2_init]235 with self.test_session() as sess:236 # Initialize all variables237 sess.run(init_all_op)238 # Check that the parameter nodes have been initialized.239 self.assertEqual(10.0, v0.eval())240 self.assertEqual(20.0, v1.eval())241 self.assertEqual(b"k1", v2.keys().eval())242 self.assertEqual(30.0, v2.values().eval())243 # Save the initialized values in the file at "save_path"244 val = save.save(sess, save_path1)245 self.assertTrue(isinstance(val, six.string_types))246 self.assertEqual(save_path1, val)247 
self.assertEqual(saver_module.latest_checkpoint(save_dir1), save_path1)248 save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")249 os.renames(save_dir1, save_dir2)250 save_path2 = os.path.join(save_dir2, "save_copy_restore")251 self.assertEqual(saver_module.latest_checkpoint(save_dir2), save_path2)252 # Start a second session. In that session the parameter nodes253 # have not been initialized either.254 with self.test_session() as sess:255 v0 = variables.Variable(-1.0, name="v0")256 v1 = variables.Variable(-1.0, name="v1")257 v2 = saver_test_utils.CheckpointedOp(name="v2")258 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})259 # Assert that the variables are not initialized.260 self.assertEqual(261 len(variables.report_uninitialized_variables().eval()), 2)262 self.assertEqual(0, len(v2.keys().eval()))263 self.assertEqual(0, len(v2.values().eval()))264 # Restore the saved values in the parameter nodes.265 save.restore(sess, save_path2)266 # Check that the parameter nodes have been restored.267 self.assertEqual(10.0, v0.eval())268 self.assertEqual(20.0, v1.eval())269 self.assertEqual(b"k1", v2.keys().eval())270 self.assertEqual(30.0, v2.values().eval())271 def testFilenameTensor(self):272 v0 = variables.Variable(0, name="v0")273 filename = b"somerandomfilename"274 save = saver_module.Saver({"v0": v0}, filename=filename)275 with self.test_session() as sess:276 tensor = sess.graph.get_tensor_by_name(277 save.saver_def.filename_tensor_name)278 self.assertEqual(sess.run(tensor), filename)279 def testInvalidPath(self):280 v0 = variables.Variable(0, name="v0")281 for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):282 with self.test_session() as sess:283 save = saver_module.Saver({"v0": v0}, write_version=ver)284 with self.assertRaisesRegexp(errors.NotFoundError,285 "Failed to find any matching files for"):286 save.restore(sess, "invalid path")287 def testInt64(self):288 save_path = os.path.join(self.get_temp_dir(), "int64")289 with self.test_session() as sess:290 # Build a graph with 1 node, and save and restore for them.291 v = variables.Variable(np.int64(15), name="v")292 save = saver_module.Saver({"v": v}, restore_sequentially=True)293 variables.global_variables_initializer().run()294 # Save the initialized values in the file at "save_path"295 val = save.save(sess, save_path)296 self.assertTrue(isinstance(val, six.string_types))297 self.assertEqual(save_path, val)298 with self.test_session() as sess:299 v = variables.Variable(np.int64(-1), name="v")300 save = saver_module.Saver({"v": v})301 with self.assertRaisesWithPredicateMatch(302 errors_impl.OpError, lambda e: "uninitialized value v" in e.message):303 sess.run(v)304 # Restore the saved values in the parameter nodes.305 save.restore(sess, save_path)306 # Check that the parameter nodes have been restored.307 self.assertEqual(np.int64(15), v.eval())308 def testSomeErrors(self):309 with ops_lib.Graph().as_default():310 v0 = variables.Variable([10.0], name="v0")311 v1 = variables.Variable([20.0], name="v1")312 v2 = variables.Variable([20.0], name="v2")313 v2._set_save_slice_info(314 variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))315 # By default the name used for "v2" will be "v1" and raise an error.316 with self.assertRaisesRegexp(ValueError, "same name: v1"):317 saver_module.Saver([v0, v1, v2])318 # The names are different and will work.319 saver_module.Saver({"vee1": v1, "other": [v2]})320 # Partitioned variables also cause name conflicts.321 p_v1 = variable_scope.get_variable(322 "p_v1",323 shape=[4, 
5],324 partitioner=partitioned_variables.fixed_size_partitioner(325 num_shards=2))326 p_v2 = variable_scope.get_variable(327 "p_v2",328 shape=[4, 5],329 partitioner=partitioned_variables.fixed_size_partitioner(330 num_shards=2))331 p_v2._name = "p_v1"332 with self.assertRaisesRegexp(ValueError, "same name: p_v1"):333 saver_module.Saver([p_v1, p_v2])334 def testSameName(self):335 with ops_lib.Graph().as_default():336 v0 = variables.Variable([10.0], name="v0")337 v2 = saver_test_utils.CheckpointedOp(name="v2")338 # Saving one variable under two names raises an error.339 with self.assertRaisesRegexp(340 ValueError, "The same saveable will be restored with two names: v0"):341 saver_module.Saver({"v0": v0, "v0too": v0})342 # Ditto for custom saveables.343 with self.assertRaisesRegexp(344 ValueError, "The same saveable will be restored with two names: v2"):345 saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable})346 # Verify non-duplicate names work.347 saver_module.Saver({"v0": v0, "v2": v2.saveable})348 def testBasicsWithListOfVariables(self):349 save_path = os.path.join(self.get_temp_dir(), "basics_with_list")350 with self.test_session(graph=ops_lib.Graph()) as sess:351 # Build a graph with 2 parameter nodes, and Save and352 # Restore nodes for them.353 v0 = variables.Variable(10.0, name="v0")354 v1 = variables.Variable(20.0, name="v1")355 v2 = saver_test_utils.CheckpointedOp(name="v2")356 v2_init = v2.insert("k1", 30.0)357 save = saver_module.Saver([v0, v1, v2.saveable])358 variables.global_variables_initializer().run()359 v2_init.run()360 # Check that the parameter nodes have been initialized.361 self.assertEqual(10.0, v0.eval())362 self.assertEqual(20.0, v1.eval())363 self.assertEqual(b"k1", v2.keys().eval())364 self.assertEqual(30.0, v2.values().eval())365 # Save the initialized values in the file at "save_path"366 val = save.save(sess, save_path)367 self.assertTrue(isinstance(val, six.string_types))368 self.assertEqual(save_path, val)369 # Start a second session. 
    # Start a second session. In that session the variables
    # have not been initialized either.
    with self.test_session(graph=ops_lib.Graph()) as sess:
      v0 = variables.Variable(-1.0, name="v0")
      v1 = variables.Variable(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      save = saver_module.Saver([v0, v1, v2.saveable])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):
        sess.run(v0)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
        sess.run(v1)
      self.assertEqual(0, len(v2.keys().eval()))
      self.assertEqual(0, len(v2.values().eval()))

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, v0.eval())
      self.assertEqual(20.0, v1.eval())
      self.assertEqual(b"k1", v2.keys().eval())
      self.assertEqual(30.0, v2.values().eval())

    # Build another graph with 2 nodes, initialized
    # differently, and a Restore node for them.
    with self.test_session(graph=ops_lib.Graph()) as sess:
      v0_2 = variables.Variable(1000.0, name="v0")
      v1_2 = variables.Variable(2000.0, name="v1")
      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
      save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
      v2_2.insert("k1000", 3000.0).run()
      variables.global_variables_initializer().run()

      # Check that the parameter nodes have been initialized.
      self.assertEqual(1000.0, v0_2.eval())
      self.assertEqual(2000.0, v1_2.eval())
      self.assertEqual(b"k1000", v2_2.keys().eval())
      self.assertEqual(3000.0, v2_2.values().eval())

      # Restore the values saved earlier in the parameter nodes.
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, v0_2.eval())
      self.assertEqual(20.0, v1_2.eval())
      self.assertEqual(b"k1", v2_2.keys().eval())
      self.assertEqual(30.0, v2_2.values().eval())

  def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
    with self.test_session(graph=ops_lib.Graph()) as sess:
      var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
      save = saver_module.Saver({var_name: var})
      if context.in_graph_mode():
        self.evaluate(var.initializer)
      val = save.save(sess, save_path)
      self.assertEqual(save_path, val)
    with self.test_session(graph=ops_lib.Graph()) as sess:
      var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
      save = saver_module.Saver({var_name: var})
      save.restore(sess, save_path)
      self.assertAllClose(var_value, self.evaluate(var))

  def testCacheRereadsFile(self):
    save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
    # Save and reload one Variable named "var0".
    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
    # Save and reload one Variable named "var1" in the same file.
    # The cached readers should know to re-read the file.
    self._SaveAndLoad("var1", 1.1, 2.2, save_path)

  def testAllowEmpty(self):
    save_path = os.path.join(self.get_temp_dir(), "allow_empty")
    with self.test_session() as sess:
      _ = constant_op.constant(1)
      save = saver_module.Saver(allow_empty=True)
      val = save.save(sess, save_path)
      self.assertIsNone(val)
    with self.test_session() as sess:
      save = saver_module.Saver(allow_empty=True)
      save.restore(sess, save_path)

  def testGPU(self):
    if not test.is_gpu_available():
      return
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_1 = variables.Variable(123.45)
      save = saver_module.Saver({"v0": v0_1})
      variables.global_variables_initializer().run()
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_2 = variables.Variable(543.21)
      save = saver_module.Saver({"v0": v0_2})
      variables.global_variables_initializer().run()

  def testVariables(self):
    save_path = os.path.join(self.get_temp_dir(), "variables")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.Variable(1.0)
      twos = variables.Variable([2.0, 2.0, 2.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      init = variables.global_variables_initializer()
      save = saver_module.Saver()
      init.run()
      v2.insert("k1", 3.0).run()
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.Variable(0.0)
      twos = variables.Variable([0.0, 0.0, 0.0])
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      # Saver with no arg, defaults to 'all variables'.
      save = saver_module.Saver()
      save.restore(sess, save_path)
      self.assertAllClose(1.0, one.eval())
      self.assertAllClose([2.0, 2.0, 2.0], twos.eval())
      self.assertEqual(b"k1", v2.keys().eval())
      self.assertEqual(3.0, v2.values().eval())

  def testVarListShouldBeEmptyInDeferredBuild(self):
    with ops_lib.Graph().as_default():
      v = variables.Variable(1.0)
      with self.assertRaisesRegexp(ValueError, "defer_build"):
        saver_module.Saver([v], defer_build=True)

  def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
    save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
    with ops_lib.Graph().as_default(), session.Session() as sess:
      variables.Variable(1.0)
      saver = saver_module.Saver(defer_build=True)
      with self.assertRaisesRegexp(RuntimeError, "build"):
        saver.save(sess, save_path)

  def testDeferredBuild(self):
    save_path = os.path.join(self.get_temp_dir(), "deferred_build")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.Variable(1.0)
      save = saver_module.Saver(defer_build=True)
      # if build is not deferred, saver cannot save the `twos`.
      twos = variables.Variable([2.0, 2.0, 2.0])
      init = variables.global_variables_initializer()
      save.build()
      init.run()
      save.save(sess, save_path)

    with session.Session("", graph=ops_lib.Graph()) as sess:
      one = variables.Variable(0.0)
      twos = variables.Variable([0.0, 0.0, 0.0])
      # Saver with no arg, defaults to 'all variables'.
      save = saver_module.Saver()
      save.restore(sess, save_path)
      self.assertAllClose(1.0, one.eval())
      self.assertAllClose([2.0, 2.0, 2.0], twos.eval())

  def testReshape(self):
    save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
      init = variables.global_variables_initializer()
      save = saver_module.Saver()
      init.run()
      save.save(sess, save_path)

    # Error when restoring with default reshape=False
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = saver_module.Saver()
      with self.assertRaisesRegexp(
          errors_impl.InvalidArgumentError,
          "Assign requires shapes of both tensors to match."):
        save.restore(sess, save_path)

    # Restored to new shape with reshape=True
    with session.Session("", graph=ops_lib.Graph()) as sess:
      var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
      save = saver_module.Saver(reshape=True)
      save.restore(sess, save_path)
      self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())

  @test_util.run_in_graph_and_eager_modes()
  def testSaveWithGlobalStep(self, pad_step_number=False):
    save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
    global_step_int = 5
    # Save and reload one Variable named "var0".
    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
    for use_tensor in [True, False]:
      with self.test_session(graph=ops_lib.Graph()):
        var = resource_variable_ops.ResourceVariable(1.0, name="var0")
        save = saver_module.Saver(
            {
                var._shared_name: var
            }, pad_step_number=pad_step_number)
        if context.in_graph_mode():
          self.evaluate(var.initializer)
          sess = ops_lib.get_default_session()
        else:
          sess = None
        if use_tensor:
          global_step = constant_op.constant(global_step_int)
          val = save.save(sess, save_path, global_step=global_step)
        else:
          val = save.save(sess, save_path, global_step=global_step_int)
        if pad_step_number:
          expected_save_path = "%s-%s" % (save_path,
                                          "{:08d}".format(global_step_int))
        else:
          expected_save_path = "%s-%d" % (save_path, global_step_int)
        self.assertEqual(expected_save_path, val)

  def testSaveWithGlobalStepWithPadding(self):
    self.testSaveWithGlobalStep(pad_step_number=True)

  def testSaveToNonexistingPath(self):
    file_io.write_string_to_file(
        os.path.join(self.get_temp_dir(), "actually_a_file"), "")
    paths = [
        os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),
        os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),
        os.path.join(self.get_temp_dir(), "actually_a_file/path"),
    ]

    for save_path in paths:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variables.Variable(10.0, name="v0")
      v1 = variables.Variable(20.0, name="v1")
      save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
      init_all_op = variables.global_variables_initializer()

      # In the case where the parent directory doesn't exist, whether or not the
      # save succeeds or fails is implementation dependent. Therefore we allow
      # both cases.
      try:
        with self.test_session() as sess:
          # Initialize all variables
          sess.run(init_all_op)

          # Check that the parameter nodes have been initialized.
          self.assertEqual(10.0, v0.eval())
          self.assertEqual(20.0, v1.eval())

          # Save the graph.
          save.save(sess, save_path)

        with self.test_session() as sess:
          # Restore the saved values in the parameter nodes.
          save.restore(sess, save_path)
          # Check that the parameter nodes have been restored.
          self.assertEqual(10.0, v0.eval())
          self.assertEqual(20.0, v1.eval())
      except ValueError as exc:
        error_msg_template = "Parent directory of {} doesn't exist, can't save."
        self.assertEqual(error_msg_template.format(save_path), str(exc))

  def testSaveToURI(self):
    # ParseURI functions don't work on Windows yet.
    # TODO(jhseu): Remove this check when it works.
    if os.name == "nt":
      self.skipTest("Local URI support doesn't work on Windows")
    save_path = "file://" + os.path.join(self.get_temp_dir(), "uri")

    # Build a graph with 2 parameter nodes, and Save and
    # Restore nodes for them.
    v0 = variables.Variable(10.0, name="v0")
    v1 = variables.Variable(20.0, name="v1")
    save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
    init_all_op = variables.global_variables_initializer()

    with self.test_session() as sess:
      # Initialize all variables
      sess.run(init_all_op)

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, v0.eval())
      self.assertEqual(20.0, v1.eval())
      save.save(sess, save_path)


class SaveRestoreShardedTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testBasics(self):
    save_path = os.path.join(self.get_temp_dir(), "sharded_basics")

    # Build a graph with 2 parameter nodes on different devices.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.Variable(10, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.Variable(20, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          sharded=True)
      variables.global_variables_initializer().run()
      t0.insert("k1", 30.0).run()
      t1.insert("k2", 40.0).run()
      val = save.save(sess, save_path)
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(save_path + "-?????-of-00002", val)
      else:
        self.assertEqual(save_path, val)
      meta_graph_filename = save._MetaGraphFilename(val)
      self.assertEqual(save_path + ".meta", meta_graph_filename)

    if save._write_version is saver_pb2.SaverDef.V1:
      # Restore different ops from shard 0 of the saved files.
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        with sess.graph.device("/cpu:0"):
          v0 = variables.Variable(111, name="v0")
          t0 = saver_test_utils.CheckpointedOp(name="t0")
          save = saver_module.Saver({"v0": v0, "t0": t0.saveable}, sharded=True)
          variables.global_variables_initializer().run()
          t0.insert("k11", 33.0).run()
          self.assertEqual(111, v0.eval())
          self.assertEqual(b"k11", t0.keys().eval())
          self.assertEqual(33.0, t0.values().eval())
          save.restore(sess, save_path + "-00000-of-00002")
          self.assertEqual(10, v0.eval())
          self.assertEqual(b"k1", t0.keys().eval())
          self.assertEqual(30.0, t0.values().eval())

      # Restore different ops from shard 1 of the saved files.
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        with sess.graph.device("/cpu:0"):
          v1 = variables.Variable(222)
          t1 = saver_test_utils.CheckpointedOp(name="t1")
          save = saver_module.Saver({"v1": v1, "t1": t1.saveable}, sharded=True)
          variables.global_variables_initializer().run()
          t1.insert("k22", 44.0).run()
          self.assertEqual(222, v1.eval())
          self.assertEqual(b"k22", t1.keys().eval())
          self.assertEqual(44.0, t1.values().eval())
          save.restore(sess, save_path + "-00001-of-00002")
          self.assertEqual(20, v1.eval())
          self.assertEqual(b"k2", t1.keys().eval())
          self.assertEqual(40.0, t1.values().eval())

    # Now try a restore with the sharded filename.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.Variable(111, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.Variable(222, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          sharded=True)
      variables.global_variables_initializer().run()
      t0.insert("k11", 33.0).run()
      t1.insert("k22", 44.0).run()
      self.assertEqual(111, v0.eval())
      self.assertEqual(222, v1.eval())
      self.assertEqual(b"k11", t0.keys().eval())
      self.assertEqual(33.0, t0.values().eval())
      self.assertEqual(b"k22", t1.keys().eval())
      self.assertEqual(44.0, t1.values().eval())

      save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
      if save._write_version is saver_pb2.SaverDef.V1:
        save.restore(sess, save_path + "-?????-of-?????")
      else:
        save.restore(sess, save_path)
      self.assertEqual(10, v0.eval())
      self.assertEqual(20, v1.eval())
      self.assertEqual(b"k1", t0.keys().eval())
      self.assertEqual(30.0, t0.values().eval())
      self.assertEqual(b"k2", t1.keys().eval())
      self.assertEqual(40.0, t1.values().eval())

    if save._write_version is saver_pb2.SaverDef.V1:
      self.assertEqual(
          saver_module.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
    else:
      self.assertEqual(
          saver_module.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics"))

  def testSaverDef(self):
    with self.test_session():
      v0 = variables.Variable(123, name="v0")
      save = saver_module.Saver({"v0": v0}, sharded=True)
      sd = save.as_saver_def()
      self.assertTrue(sd.sharded)

  def _testPartitionedVariables(self, use_resource):
    var_full_shape = [10, 3]
    # Allows save/restore mechanism to work w/ different slicings.
    var_name = "my_var"
    saved_dir = self._get_test_dir("partitioned_variables")
    saved_path = os.path.join(saved_dir, "ckpt")

    call_saver_with_dict = False  # updated by test loop below

    def _save(slices=None, partitioner=None):
      with self.test_session(graph=ops_lib.Graph()) as sess:
        # Calls .eval() to return the ndarray that makes up the full variable.
        rnd = random_ops.random_uniform(var_full_shape).eval()

        if slices:
          assert not partitioner
          # TODO(apassos): make create_partitioned_variables take use_resource
          # option to make this test passable without creating a named
          # variable_scope.
          vs = partitioned_variables.create_partitioned_variables(
              var_full_shape, slices, rnd, name=var_name)
        elif partitioner:
          vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=rnd,
                  partitioner=partitioner,
                  use_resource=use_resource)
          ]
        else:
          if use_resource:
            vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
          else:
            vs = [variables.Variable(rnd, name=var_name)]

        variables.global_variables_initializer().run()
        if call_saver_with_dict:
          saver = saver_module.Saver({var_name: (vs if slices else vs[0])})
        else:
          saver = saver_module.Saver(vs)
        actual_path = saver.save(sess, saved_path)
        self.assertEqual(saved_path, actual_path)
        return rnd

    def _restore(slices=None, partitioner=None):
      with self.test_session(graph=ops_lib.Graph()) as sess:
        if slices:
          assert not partitioner
          new_vs = partitioned_variables.create_partitioned_variables(
              var_full_shape,
              slices,
              array_ops.zeros(var_full_shape),  # != original contents.
              name=var_name)
        elif partitioner:
          new_vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=array_ops.zeros(var_full_shape),
                  partitioner=partitioner)
          ]
        else:
          new_vs = [
              variables.Variable(
                  array_ops.zeros(
                      shape=var_full_shape),  # != original contents.
                  name=var_name)
          ]

        variables.global_variables_initializer().run()
        if call_saver_with_dict:
          saver = saver_module.Saver({
              var_name: (new_vs if slices else new_vs[0])
          })
        else:
          saver = saver_module.Saver(new_vs)
        saver.restore(sess, saved_path)

        if partitioner:
          return new_vs[0].as_tensor().eval()
        elif slices and slices[0] != 1:
          return array_ops.concat(new_vs, 0).eval()
        elif slices and slices[1] != 1:
          return array_ops.concat(new_vs, 1).eval()
        else:  # Non-sliced.
          return new_vs[0].eval()

    for call_saver_with_dict in {False, True}:
      # Save PartitionedVariable and restore into full variable.
      saved_full = _save(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      restored_full = _restore()
      self.assertAllEqual(saved_full, restored_full)

      # Saves 10 horizontal parts of a partitioned variable.
      # Restores into a full variable, non-sliced.
      saved_full = _save(slices=[10, 1])
      restored_full = _restore()
      self.assertAllEqual(saved_full, restored_full)

      # Restores into a different number/orientation of slices.
      restored_full = _restore(slices=[2, 1])  # 2 horizon parts.
      self.assertAllEqual(saved_full, restored_full)
      restored_full = _restore(slices=[1, 3])  # 3 vertical parts.
      self.assertAllEqual(saved_full, restored_full)

      # Restores into a PartitionedVariable
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      self.assertAllEqual(saved_full, restored_full)

      # Now, saves a full variable and restores in slices.
      saved_full = _save()
      restored_full = _restore(slices=[1, 3])
      self.assertAllEqual(saved_full, restored_full)

  def testPartitionedVariable(self):
    self._testPartitionedVariables(use_resource=False)

  def testPartitionedResourceVariable(self):
    self._testPartitionedVariables(use_resource=True)


class MaxToKeepTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def assertCheckpointState(self, model_checkpoint_path,
                            all_model_checkpoint_paths, save_dir):
    checkpoint_state = saver_module.get_checkpoint_state(save_dir)
    self.assertEqual(checkpoint_state.model_checkpoint_path,
                     model_checkpoint_path)
    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
                     all_model_checkpoint_paths)

  def testNonSharded(self):
    save_dir = self._get_test_dir("max_to_keep_non_sharded")

    with self.test_session() as sess:
      v = variables.Variable(10.0, name="v")
      save = saver_module.Saver({"v": v}, max_to_keep=2)
      variables.global_variables_initializer().run()
      self.assertEqual([], save.last_checkpoints)

      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s1],
          save_dir=save_dir)

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s1, s2],
          save_dir=save_dir)

      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s1))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(saver_module.checkpoint_exists(s3))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=[s2, s3],
          save_dir=save_dir)

      # Create a second helper, identical to the first.
      save2 = saver_module.Saver(saver_def=save.as_saver_def())
      save2.set_last_checkpoints(save.last_checkpoints)

      # Create a third helper, with the same configuration but no knowledge of
      # previous checkpoints.
      save3 = saver_module.Saver(saver_def=save.as_saver_def())

      # Exercise the first helper.
      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s1))
      self.assertFalse(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
      self.assertTrue(saver_module.checkpoint_exists(s3))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

      # Exercise the second helper.
      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save2.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
      # Deleted by the first helper.
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save2.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save2.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

      # Exercise the third helper.
      # Adding s2 again (but helper is unaware of previous s2)
      s2 = save3.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s2], save3.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
      # Deleted by the first helper.
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
      # Even though the file for s1 exists, this saver isn't aware of it, which
      # is why it doesn't end up in the checkpoint state.
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s2],
          save_dir=save_dir)

      # Adding s1 (s3 should not be deleted because helper is unaware of it)
      s1 = save3.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save3.last_checkpoints)
      self.assertFalse(saver_module.checkpoint_exists(s3))
      self.assertFalse(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
      self.assertTrue(saver_module.checkpoint_exists(s2))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertTrue(
          saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

  def testSharded(self):
    save_dir = self._get_test_dir("max_to_keep_sharded")

    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.Variable(111, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.Variable(222, name="v1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1
          }, sharded=True, max_to_keep=2)
      variables.global_variables_initializer().run()
      self.assertEqual([], save.last_checkpoints)

      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s1)))
      else:
        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s1)))
      else:
        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s2)))
      else:
        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))

      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertEqual(0, len(gfile.Glob(s1 + "*")))
      self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s2)))
      else:
        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s3)))
      else:
        self.assertEqual(4, len(gfile.Glob(s3 + "*")))
      self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))

  def testNoMaxToKeep(self):
    save_dir = self._get_test_dir("no_max_to_keep")
    save_dir2 = self._get_test_dir("max_to_keep_0")

    with self.test_session() as sess:
      v = variables.Variable(10.0, name="v")
      variables.global_variables_initializer().run()

      # Test max_to_keep being None.
      save = saver_module.Saver({"v": v}, max_to_keep=None)
      self.assertEqual([], save.last_checkpoints)
      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([], save.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s2))

      # Test max_to_keep being 0.
      save2 = saver_module.Saver({"v": v}, max_to_keep=0)
      self.assertEqual([], save2.last_checkpoints)
      s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
      self.assertEqual([], save2.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
      self.assertEqual([], save2.last_checkpoints)
      self.assertTrue(saver_module.checkpoint_exists(s2))

  def testNoMetaGraph(self):
    save_dir = self._get_test_dir("no_meta_graph")

    with self.test_session() as sess:
      v = variables.Variable(10.0, name="v")
      save = saver_module.Saver({"v": v})
      variables.global_variables_initializer().run()

      s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))


class KeepCheckpointEveryNHoursTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  @test.mock.patch.object(saver_module, "time")
  def testNonSharded(self, mock_time):
    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")

    with self.test_session() as sess:
      v = variables.Variable([10.0], name="v")
      # Run the initializer NOW to avoid the 0.5s overhead of the first Run()
      # call, which throws the test timing off in fastbuild mode.
      variables.global_variables_initializer().run()
      # Create a saver that will keep the last 2 checkpoints plus one every 0.7
      # seconds.
      start_time = time.time()
      mock_time.time.return_value = start_time
      save = saver_module.Saver(
          {
              "v": v
          }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600)
      self.assertEqual([], save.last_checkpoints)

      # Wait till 1 seconds have elapsed so s1 will be old enough to keep.
      # sleep may return early, don't trust it.
      mock_time.time.return_value = start_time + 1.0

      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)

      # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save(),
      # would normally delete s1, because max_to_keep is 2. However, s1 is
      # older than 0.7s so we must keep it.
      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)

      # s1 should still be here, we are Not checking now to reduce time
      # variance in the test.

      # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next
      # call to Save(), will delete s2, because max_to_keep is 2, and because
      # we already kept the old s1. s2 is very close in time to s1 so it gets
      # deleted.
      s4 = save.save(sess, os.path.join(save_dir, "s4"))
      self.assertEqual([s3, s4], save.last_checkpoints)

      # Check that s1 is still here, but s2 is gone.
      self.assertTrue(saver_module.checkpoint_exists(s1))
      self.assertFalse(saver_module.checkpoint_exists(s2))
      self.assertTrue(saver_module.checkpoint_exists(s3))
      self.assertTrue(saver_module.checkpoint_exists(s4))


class SaveRestoreWithVariableNameMap(test.TestCase):

  def _testNonReshape(self, variable_op):
    save_path = os.path.join(self.get_temp_dir(), "non_reshape")

    with self.test_session(graph=ops_lib.Graph()) as sess:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variable_op(10.0, name="v0")
      v1 = variable_op(20.0, name="v1")
      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      # Save the initialized values in the file at "save_path"
      # Use a variable name map to set the saved tensor names
      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

      # Verify that the original names are not in the Saved file
      save = saver_module.Saver({"v0": v0, "v1": v1})
      with self.assertRaisesOpError("not found in checkpoint"):
        save.restore(sess, save_path)

    # Verify that the mapped names are present in the Saved file and can be
    # Restored using remapped names.
    with self.test_session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="v0")
      v1 = variable_op(-1.0, name="v1")

      if context.in_graph_mode():
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v0)
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v1)

      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      save.restore(sess, save_path)

      # Check that the parameter nodes have been restored.
      if context.in_graph_mode():
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))

    # Add a prefix to the node names in the current graph and Restore using
    # remapped names.
    with self.test_session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="restore_prefix/v0")
      v1 = variable_op(-1.0, name="restore_prefix/v1")

      if context.in_graph_mode():
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v0)
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v1)

      # Restore the saved values in the parameter nodes.
      save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      save.restore(sess, save_path)

      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

  @test_util.run_in_graph_and_eager_modes()
  def testNonReshapeResourceVariable(self):
    self._testNonReshape(resource_variable_ops.ResourceVariable)

  def testNonReshapeVariable(self):
    self._testNonReshape(variables.Variable)


class LatestCheckpointWithRelativePaths(test.TestCase):

  @staticmethod
  @contextlib.contextmanager
  def tempWorkingDir(temppath):
    cwd = os.getcwd()
    os.chdir(temppath)
    try:
      yield
    finally:
      os.chdir(cwd)

  @staticmethod
  @contextlib.contextmanager
  def tempDir():
    tempdir = tempfile.mkdtemp()
    try:
      yield tempdir
    finally:
      shutil.rmtree(tempdir)

  def testNameCollision(self):
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")

        with self.test_session() as sess:
          unused_a = variables.Variable(0.0)  # So that Saver saves something.
          variables.global_variables_initializer().run()

          # Should fail.
          saver = saver_module.Saver(sharded=False)
          with self.assertRaisesRegexp(ValueError, "collides with"):
            saver.save(sess, filepath)

          # Succeeds: the file will be named "checkpoint-<step>".
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath)
          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(saver_module.latest_checkpoint(traindir))

  def testRelativePath(self):
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)

        filename = "snapshot"
        filepath = os.path.join(traindir, filename)

        with self.test_session() as sess:
          # Build a simple graph.
          v0 = variables.Variable(0.0)
          inc = v0.assign_add(1.0)

          save = saver_module.Saver({"v0": v0})

          # Record a short training history.
          variables.global_variables_initializer().run()
          save.save(sess, filepath, global_step=0)
          inc.eval()
          save.save(sess, filepath, global_step=1)
          inc.eval()
          save.save(sess, filepath, global_step=2)

        with self.test_session() as sess:
          # Build a new graph with different initialization.
          v0 = variables.Variable(-1.0)

          # Create a new saver.
          save = saver_module.Saver({"v0": v0})
          variables.global_variables_initializer().run()

          # Get the most recent checkpoint name from the training history file.
          name = saver_module.latest_checkpoint(traindir)
          self.assertIsNotNone(name)

          # Restore "v0" from that checkpoint.
          save.restore(sess, name)
          self.assertEqual(v0.eval(), 2.0)


class CheckpointStateTest(test.TestCase):

  def _get_test_dir(self, dirname):
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testAbsPath(self):
    save_dir = self._get_test_dir("abs_paths")
    abs_path = os.path.join(save_dir, "model-0")
    ckpt = saver_module.generate_checkpoint_state_proto(save_dir, abs_path)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path)
    self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testRelPath(self):
    train_dir = "train"
    model = os.path.join(train_dir, "model-0")
    # model_checkpoint_path should have no "train" directory part.
    new_rel_path = "model-0"
    ckpt = saver_module.generate_checkpoint_state_proto(train_dir, model)
    self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)

  def testAllModelCheckpointPaths(self):
    save_dir = self._get_test_dir("all_models_test")
    abs_path = os.path.join(save_dir, "model-0")
    for paths in [None, [], ["model-2"]]:
      ckpt = saver_module.generate_checkpoint_state_proto(
          save_dir, abs_path, all_model_checkpoint_paths=paths)
      self.assertEqual(ckpt.model_checkpoint_path, abs_path)
      self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
      self.assertEqual(
          len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1)
      self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testUpdateCheckpointState(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)
    # Make a temporary train directory.
    train_dir = "train"
    os.mkdir(train_dir)
    abs_path = os.path.join(save_dir, "model-0")
    rel_path = os.path.join("train", "model-2")
    saver_module.update_checkpoint_state(
        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
    ckpt = saver_module.get_checkpoint_state(train_dir)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)

  def testUpdateCheckpointStateSaveRelativePaths(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)
    abs_path2 = os.path.join(save_dir, "model-2")
    rel_path2 = "model-2"
os.path.join(save_dir, "model-0")1326 rel_path0 = "model-0"1327 saver_module._update_checkpoint_state( # pylint: disable=protected-access1328 save_dir=save_dir,1329 model_checkpoint_path=abs_path2,1330 all_model_checkpoint_paths=[rel_path0, abs_path2],1331 save_relative_paths=True)1332 # File should contain relative paths.1333 file_content = file_io.read_file_to_string(1334 os.path.join(save_dir, "checkpoint"))1335 ckpt = CheckpointState()1336 text_format.Merge(file_content, ckpt)1337 self.assertEqual(ckpt.model_checkpoint_path, rel_path2)1338 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)1339 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)1340 self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)1341 # get_checkpoint_state should return absolute paths.1342 ckpt = saver_module.get_checkpoint_state(save_dir)1343 self.assertEqual(ckpt.model_checkpoint_path, abs_path2)1344 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)1345 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)1346 self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)1347 def testCheckPointStateFailsWhenIncomplete(self):1348 save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")1349 os.chdir(save_dir)1350 ckpt_path = os.path.join(save_dir, "checkpoint")1351 ckpt_file = open(ckpt_path, "w")1352 ckpt_file.write("")1353 ckpt_file.close()1354 with self.assertRaises(ValueError):1355 saver_module.get_checkpoint_state(save_dir)1356 def testCheckPointCompletesRelativePaths(self):1357 save_dir = self._get_test_dir("checkpoint_completes_relative_paths")1358 os.chdir(save_dir)1359 ckpt_path = os.path.join(save_dir, "checkpoint")1360 ckpt_file = open(ckpt_path, "w")1361 ckpt_file.write("""1362 model_checkpoint_path: "./model.ckpt-687529"1363 all_model_checkpoint_paths: "./model.ckpt-687500"1364 all_model_checkpoint_paths: "./model.ckpt-687529"1365 """)1366 ckpt_file.close()1367 ckpt = saver_module.get_checkpoint_state(save_dir)1368 self.assertEqual(ckpt.model_checkpoint_path,1369 os.path.join(save_dir, "./model.ckpt-687529"))1370 self.assertEqual(ckpt.all_model_checkpoint_paths[0],1371 os.path.join(save_dir, "./model.ckpt-687500"))1372 self.assertEqual(ckpt.all_model_checkpoint_paths[1],1373 os.path.join(save_dir, "./model.ckpt-687529"))1374class MetaGraphTest(test.TestCase):1375 def _get_test_dir(self, dirname):1376 test_dir = os.path.join(self.get_temp_dir(), dirname)1377 gfile.MakeDirs(test_dir)1378 return test_dir1379 def testAddCollectionDef(self):1380 test_dir = self._get_test_dir("good_collection")1381 filename = os.path.join(test_dir, "metafile")1382 with self.test_session():1383 # Creates a graph.1384 v0 = variables.Variable(1.0, name="v0")1385 control_flow_ops.cond(1386 math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),1387 lambda: math_ops.subtract(v0, 1))1388 control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),1389 lambda i: math_ops.add(i, 1), [v0])1390 var = variables.Variable(constant_op.constant(0, dtype=dtypes.int64))1391 count_up_to = var.count_up_to(3)1392 input_queue = data_flow_ops.FIFOQueue(1393 30, dtypes.float32, shared_name="collection_queue")1394 qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to])1395 variables.global_variables_initializer()1396 # Creates a saver.1397 save = saver_module.Saver({"v0": v0})1398 # Adds a set of collections.1399 ops_lib.add_to_collection("int_collection", 3)1400 ops_lib.add_to_collection("float_collection", 3.5)1401 ops_lib.add_to_collection("string_collection", "hello")1402 
ops_lib.add_to_collection("variable_collection", v0)1403 # Add QueueRunners.1404 queue_runner_impl.add_queue_runner(qr)1405 # Adds user_defined proto in three formats: string, bytes and Any.1406 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")1407 ops_lib.add_to_collection("user_defined_string_collection",1408 str(queue_runner))1409 ops_lib.add_to_collection("user_defined_bytes_collection",1410 queue_runner.SerializeToString())1411 any_buf = Any()1412 any_buf.Pack(queue_runner)1413 ops_lib.add_to_collection("user_defined_any_collection", any_buf)1414 # Generates MetaGraphDef.1415 meta_graph_def = save.export_meta_graph(filename)1416 self.assertTrue(meta_graph_def.HasField("saver_def"))1417 self.assertTrue(meta_graph_def.HasField("graph_def"))1418 self.assertTrue(meta_graph_def.HasField("meta_info_def"))1419 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")1420 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,1421 "")1422 collection_def = meta_graph_def.collection_def1423 self.assertEqual(len(collection_def), 12)1424 with ops_lib.Graph().as_default():1425 # Restores from MetaGraphDef.1426 new_saver = saver_module.import_meta_graph(filename)1427 # Generates a new MetaGraphDef.1428 new_meta_graph_def = new_saver.export_meta_graph()1429 # It should be the same as the original.1430 test_util.assert_meta_graph_protos_equal(1431 self, meta_graph_def, new_meta_graph_def)1432 def testAddCollectionDefFails(self):1433 with self.test_session():1434 # Creates a graph.1435 v0 = variables.Variable(10.0, name="v0")1436 # Creates a saver.1437 save = saver_module.Saver({"v0": v0})1438 # Generates MetaGraphDef.1439 meta_graph_def = meta_graph_pb2.MetaGraphDef()1440 # Verifies that collection with unsupported key will not be added.1441 ops_lib.add_to_collection(save, 3)1442 save._add_collection_def(meta_graph_def, save)1443 self.assertEqual(len(meta_graph_def.collection_def), 0)1444 # Verifies that collection where item type does not match expected1445 # type will not be added.1446 ops_lib.add_to_collection("int_collection", 3)1447 ops_lib.add_to_collection("int_collection", 3.5)1448 save._add_collection_def(meta_graph_def, "int_collection")1449 self.assertEqual(len(meta_graph_def.collection_def), 0)1450 def _testMultiSaverCollectionSave(self, test_dir):1451 filename = os.path.join(test_dir, "metafile")1452 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1453 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")1454 with self.test_session(graph=ops_lib.Graph()) as sess:1455 # Creates a graph.1456 v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")1457 v1 = variables.Variable(11.0, name="v1")1458 # Creates 2 savers.1459 saver0 = saver_module.Saver({"v0": v0}, name="saver0")1460 saver1 = saver_module.Saver({"v1": v1}, name="saver1")1461 ops_lib.add_to_collection("savers", saver0)1462 ops_lib.add_to_collection("savers", saver1)1463 variables.global_variables_initializer().run()1464 # Saves to different checkpoints.1465 saver0.save(sess, saver0_ckpt)1466 saver1.save(sess, saver1_ckpt)1467 # Generates MetaGraphDef.1468 meta_graph_def = saver_module.export_meta_graph(filename)1469 meta_graph_def0 = saver0.export_meta_graph()1470 meta_graph_def1 = saver1.export_meta_graph()1471 # Verifies that there is no saver_def in meta_graph_def.1472 self.assertFalse(meta_graph_def.HasField("saver_def"))1473 # Verifies that there is saver_def in meta_graph_def0 and 1.1474 self.assertTrue(meta_graph_def0.HasField("saver_def"))1475 
self.assertTrue(meta_graph_def1.HasField("saver_def"))1476 # Verifies SAVERS is saved as bytes_list for meta_graph_def.1477 collection_def = meta_graph_def.collection_def["savers"]1478 kind = collection_def.WhichOneof("kind")1479 self.assertEqual(kind, "bytes_list")1480 # Verifies that there are 2 entries in SAVERS collection.1481 savers = getattr(collection_def, kind)1482 self.assertEqual(2, len(savers.value))1483 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0.1484 collection_def = meta_graph_def0.collection_def["savers"]1485 kind = collection_def.WhichOneof("kind")1486 self.assertEqual(kind, "bytes_list")1487 # Verifies that there are 2 entries in SAVERS collection.1488 savers = getattr(collection_def, kind)1489 self.assertEqual(2, len(savers.value))1490 def _testMultiSaverCollectionRestore(self, test_dir):1491 filename = os.path.join(test_dir, "metafile")1492 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1493 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")1494 with self.test_session(graph=ops_lib.Graph()) as sess:1495 # Imports from meta_graph.1496 saver_module.import_meta_graph(filename)1497 # Retrieves SAVERS collection. Verifies there are 2 entries.1498 savers = ops_lib.get_collection("savers")1499 self.assertEqual(2, len(savers))1500 # Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1.1501 new_saver0 = savers[0]1502 new_saver0.restore(sess, saver0_ckpt)1503 v0 = sess.graph.get_tensor_by_name("v0:0")1504 v1 = sess.graph.get_tensor_by_name("v1:0")1505 self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], v0.eval())1506 self.assertEqual([3, 2], v0.get_shape())1507 self.assertEqual([], v1.get_shape())1508 with self.assertRaisesWithPredicateMatch(1509 errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):1510 sess.run(v1)1511 # Retrieves saver1. 
Verifies that new_saver1 can restore v1.1512 new_saver1 = savers[1]1513 new_saver1.restore(sess, saver1_ckpt)1514 v1 = sess.graph.get_tensor_by_name("v1:0")1515 self.assertEqual(11.0, v1.eval())1516 def testMultiSaverCollection(self):1517 test_dir = self._get_test_dir("saver_collection")1518 self._testMultiSaverCollectionSave(test_dir)1519 self._testMultiSaverCollectionRestore(test_dir)1520 def testClearExtraneousSavers(self):1521 test_dir = self._get_test_dir("clear_extraneous_savers")1522 filename = os.path.join(test_dir, "metafile")1523 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1524 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")1525 with self.test_session(graph=ops_lib.Graph()) as sess:1526 # Creates a graph.1527 v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")1528 v1 = variables.Variable(11.0, name="v1")1529 # Creates 2 savers.1530 saver0 = saver_module.Saver({"v0": v0}, name="saver0")1531 saver1 = saver_module.Saver({"v1": v1}, name="saver1")1532 ops_lib.add_to_collection("savers", saver0)1533 ops_lib.add_to_collection("savers", saver1)1534 variables.global_variables_initializer().run()1535 # Saves to different checkpoints.1536 saver0.save(sess, saver0_ckpt)1537 saver1.save(sess, saver1_ckpt)1538 # Generates MetaGraphDef.1539 meta_graph_def = saver_module.export_meta_graph(filename)1540 meta_graph_def0 = saver0.export_meta_graph()1541 meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True)1542 # Verifies that there is no saver_def in meta_graph_def.1543 self.assertFalse(meta_graph_def.HasField("saver_def"))1544 # Verifies that there is saver_def in meta_graph_def0 and 1.1545 self.assertTrue(meta_graph_def0.HasField("saver_def"))1546 self.assertTrue(meta_graph_def1.HasField("saver_def"))1547 # Verifies SAVERS is saved as bytes_list for meta_graph_def.1548 collection_def = meta_graph_def.collection_def["savers"]1549 kind = collection_def.WhichOneof("kind")1550 self.assertEqual(kind, "bytes_list")1551 # Verifies that there are 2 entries in SAVERS collection.1552 savers = getattr(collection_def, kind)1553 self.assertEqual(2, len(savers.value))1554 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1.1555 collection_def = meta_graph_def1.collection_def["savers"]1556 kind = collection_def.WhichOneof("kind")1557 self.assertEqual(kind, "bytes_list")1558 # Verifies that there is 1 entry in SAVERS collection.1559 savers = getattr(collection_def, kind)1560 self.assertEqual(1, len(savers.value))1561 # Verifies that saver0 graph nodes are omitted from the saver1 export1562 self.assertEqual(29, len(meta_graph_def0.graph_def.node))1563 self.assertEqual(19, len(meta_graph_def1.graph_def.node))1564 def testBinaryAndTextFormat(self):1565 test_dir = self._get_test_dir("binary_and_text")1566 filename = os.path.join(test_dir, "metafile")1567 with self.test_session(graph=ops_lib.Graph()):1568 # Creates a graph.1569 variables.Variable(10.0, name="v0")1570 # Exports the graph as binary format.1571 saver_module.export_meta_graph(filename, as_text=False)1572 with self.test_session(graph=ops_lib.Graph()):1573 # Imports the binary format graph.1574 saver = saver_module.import_meta_graph(filename)1575 self.assertIsNotNone(saver)1576 # Exports the graph as text format.1577 saver.export_meta_graph(filename, as_text=True)1578 with self.test_session(graph=ops_lib.Graph()):1579 # Imports the text format graph.1580 saver_module.import_meta_graph(filename)1581 # Writes wrong contents to the file.1582 graph_io.write_graph(saver.as_saver_def(),1583 
os.path.dirname(filename),1584 os.path.basename(filename))1585 with self.test_session(graph=ops_lib.Graph()):1586 # Import should fail.1587 with self.assertRaisesWithPredicateMatch(IOError,1588 lambda e: "Cannot parse file"):1589 saver_module.import_meta_graph(filename)1590 # Deletes the file1591 gfile.Remove(filename)1592 with self.assertRaisesWithPredicateMatch(IOError,1593 lambda e: "does not exist"):1594 saver_module.import_meta_graph(filename)1595 def testSliceVariable(self):1596 test_dir = self._get_test_dir("slice_saver")1597 filename = os.path.join(test_dir, "metafile")1598 with self.test_session():1599 v1 = variables.Variable([20.0], name="v1")1600 v2 = variables.Variable([20.0], name="v2")1601 v2._set_save_slice_info(1602 variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))1603 # The names are different and will work.1604 slice_saver = saver_module.Saver({"first": v1, "second": v2})1605 variables.global_variables_initializer().run()1606 # Exports to meta_graph1607 meta_graph_def = slice_saver.export_meta_graph(filename)1608 with ops_lib.Graph().as_default():1609 # Restores from MetaGraphDef.1610 new_saver = saver_module.import_meta_graph(filename)1611 self.assertIsNotNone(new_saver)1612 # Generates a new MetaGraphDef.1613 new_meta_graph_def = new_saver.export_meta_graph()1614 # It should be the same as the original.1615 self.assertProtoEquals(meta_graph_def, new_meta_graph_def)1616 def _testGraphExtensionSave(self, test_dir):1617 filename = os.path.join(test_dir, "metafile")1618 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1619 # Creates an inference graph.1620 # Hidden 11621 images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])1622 with ops_lib.name_scope("hidden1"):1623 weights = variables.Variable(1624 random_ops.truncated_normal(1625 [28, 128], stddev=1.0 / math.sqrt(float(28))),1626 name="weights")1627 # The use of control_flow_ops.cond here is purely for adding test coverage1628 # the save and restore of control flow context (which doesn't make any1629 # sense here from a machine learning perspective). The typical biases is1630 # a simple Variable without the conditions.1631 biases = variables.Variable(1632 control_flow_ops.cond(1633 math_ops.less(random.random(), 0.5),1634 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),1635 name="biases")1636 hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)1637 # Hidden 21638 with ops_lib.name_scope("hidden2"):1639 weights = variables.Variable(1640 random_ops.truncated_normal(1641 [128, 32], stddev=1.0 / math.sqrt(float(128))),1642 name="weights")1643 # The use of control_flow_ops.while_loop here is purely for adding test1644 # coverage the save and restore of control flow context (which doesn't1645 # make any sense here from a machine learning perspective). 
The typical1646 # biases is a simple Variable without the conditions.1647 def loop_cond(it, _):1648 return it < 21649 def loop_body(it, biases):1650 biases += constant_op.constant(0.1, shape=[32])1651 return it + 1, biases1652 _, biases = control_flow_ops.while_loop(1653 loop_cond, loop_body,1654 [constant_op.constant(0), variables.Variable(array_ops.zeros([32]))])1655 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)1656 # Linear1657 with ops_lib.name_scope("softmax_linear"):1658 weights = variables.Variable(1659 random_ops.truncated_normal(1660 [32, 10], stddev=1.0 / math.sqrt(float(32))),1661 name="weights")1662 biases = variables.Variable(array_ops.zeros([10]), name="biases")1663 logits = math_ops.matmul(hidden2, weights) + biases1664 ops_lib.add_to_collection("logits", logits)1665 init_all_op = variables.global_variables_initializer()1666 with self.test_session() as sess:1667 # Initializes all the variables.1668 sess.run(init_all_op)1669 # Runs to logit.1670 sess.run(logits)1671 # Creates a saver.1672 saver0 = saver_module.Saver()1673 saver0.save(sess, saver0_ckpt)1674 # Generates MetaGraphDef.1675 saver0.export_meta_graph(filename)1676 def _testGraphExtensionRestore(self, test_dir):1677 filename = os.path.join(test_dir, "metafile")1678 train_filename = os.path.join(test_dir, "train_metafile")1679 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1680 with self.test_session(graph=ops_lib.Graph()) as sess:1681 # Restores from MetaGraphDef.1682 new_saver = saver_module.import_meta_graph(filename)1683 # Generates a new MetaGraphDef.1684 new_saver.export_meta_graph()1685 # Restores from checkpoint.1686 new_saver.restore(sess, saver0_ckpt)1687 # Adds loss and train.1688 labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels")1689 batch_size = array_ops.size(labels)1690 labels = array_ops.expand_dims(labels, 1)1691 indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1)1692 concated = array_ops.concat([indices, labels], 1)1693 onehot_labels = sparse_ops.sparse_to_dense(1694 concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)1695 logits = ops_lib.get_collection("logits")[0]1696 cross_entropy = nn_ops.softmax_cross_entropy_with_logits(1697 labels=onehot_labels, logits=logits, name="xentropy")1698 loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")1699 summary.scalar("loss", loss)1700 # Creates the gradient descent optimizer with the given learning rate.1701 optimizer = gradient_descent.GradientDescentOptimizer(0.01)1702 # Runs train_op.1703 train_op = optimizer.minimize(loss)1704 ops_lib.add_to_collection("train_op", train_op)1705 # Runs train_op.1706 sess.run(train_op)1707 # Generates MetaGraphDef.1708 saver_module.export_meta_graph(train_filename)1709 def _testRestoreFromTrainGraphWithControlContext(self, test_dir):1710 train_filename = os.path.join(test_dir, "train_metafile")1711 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")1712 with self.test_session(graph=ops_lib.Graph()) as sess:1713 # Restores from MetaGraphDef.1714 new_saver = saver_module.import_meta_graph(train_filename)1715 # Restores from checkpoint.1716 new_saver.restore(sess, saver0_ckpt)1717 train_op = ops_lib.get_collection("train_op")[0]1718 sess.run(train_op)1719 def testGraphExtension(self):1720 test_dir = self._get_test_dir("graph_extension")1721 self._testGraphExtensionSave(test_dir)1722 self._testGraphExtensionRestore(test_dir)1723 self._testRestoreFromTrainGraphWithControlContext(test_dir)1724 def testStrippedOpListDef(self):1725 with 
self.test_session():1726 # Creates a graph.1727 v0 = variables.Variable(0.0)1728 var = variables.Variable(10.0)1729 math_ops.add(v0, var)1730 @function.Defun(dtypes.float32)1731 def minus_one(x):1732 return x - 11733 minus_one(array_ops.identity(v0))1734 save = saver_module.Saver({"v0": v0})1735 variables.global_variables_initializer()1736 # Generates MetaGraphDef.1737 meta_graph_def = save.export_meta_graph()1738 ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]1739 if save._write_version is saver_pb2.SaverDef.V1:1740 self.assertEqual(ops, [1741 "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2",1742 "SaveSlices", "Sub", "VariableV2"1743 ])1744 else:1745 self.assertEqual(ops, [1746 "Add", "Assign", "Const", "Identity", "NoOp", "RestoreV2", "SaveV2",1747 "Sub", "VariableV2"1748 ])1749 # Test calling stripped_op_list_for_graph directly1750 op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)1751 self.assertEqual(ops, [o.name for o in op_list.op])1752 for o in op_list.op:1753 self.assertEqual(o.summary, "")1754 self.assertEqual(o.description, "")1755 def testImportIntoNamescope(self):1756 # Test that we can import a meta graph into a namescope.1757 test_dir = self._get_test_dir("import_into_namescope")1758 filename = os.path.join(test_dir, "ckpt")1759 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")1760 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")1761 with session.Session() as sess:1762 weights = variables.Variable(1763 random_ops.random_uniform([784, 10]), name="weights")1764 bias = variables.Variable(array_ops.zeros([10]), name="bias")1765 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")1766 nn_ops.softmax(logit, name="prediction")1767 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,1768 logits=logit, name="cost")1769 adam.AdamOptimizer().minimize(cost, name="optimize")1770 saver = saver_module.Saver()1771 sess.run(variables.global_variables_initializer())1772 saver.save(sess, filename)1773 graph = ops_lib.Graph()1774 with session.Session(graph=graph) as sess:1775 new_saver = saver_module.import_meta_graph(1776 filename + ".meta", graph=graph, import_scope="new_model")1777 new_saver.restore(sess, filename)1778 sess.run(["new_model/optimize"], {1779 "new_model/image:0": np.random.random([1, 784]),1780 "new_model/label:0": np.random.randint(1781 10, size=[1, 10])1782 })1783 def testClearDevicesOnImport(self):1784 # Test that we import a graph without its devices and run successfully.1785 with ops_lib.Graph().as_default():1786 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):1787 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")1788 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")1789 weights = variables.Variable(1790 random_ops.random_uniform([784, 10]), name="weights")1791 bias = variables.Variable(array_ops.zeros([10]), name="bias")1792 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)1793 nn_ops.softmax(logit, name="prediction")1794 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,1795 logits=logit)1796 adam.AdamOptimizer().minimize(cost, name="optimize")1797 meta_graph_def = saver_module.export_meta_graph()1798 with session.Session(graph=ops_lib.Graph()) as sess:1799 saver_module.import_meta_graph(1800 meta_graph_def, clear_devices=False, import_scope="new_model")1801 # Device refers to GPU, which is not available here.1802 with 
self.assertRaises(errors_impl.InvalidArgumentError):1803 sess.run(variables.global_variables_initializer())1804 with session.Session(graph=ops_lib.Graph()) as sess:1805 saver_module.import_meta_graph(1806 meta_graph_def, clear_devices=True, import_scope="new_model")1807 sess.run(variables.global_variables_initializer())1808 sess.run(["new_model/optimize"], {1809 "new_model/image:0": np.random.random([1, 784]),1810 "new_model/label:0": np.random.randint(1811 10, size=[1, 10])1812 })1813 def testClearDevicesOnExport(self):1814 # Test that we export a graph without its devices and run successfully.1815 with ops_lib.Graph().as_default():1816 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):1817 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")1818 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")1819 weights = variables.Variable(1820 random_ops.random_uniform([784, 10]), name="weights")1821 bias = variables.Variable(array_ops.zeros([10]), name="bias")1822 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)1823 nn_ops.softmax(logit, name="prediction")1824 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,1825 logits=logit)1826 adam.AdamOptimizer().minimize(cost, name="optimize")1827 meta_graph_def = saver_module.export_meta_graph(clear_devices=True)1828 graph_io.write_graph(meta_graph_def, self.get_temp_dir(),1829 "meta_graph.pbtxt")1830 with session.Session(graph=ops_lib.Graph()) as sess:1831 saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")1832 sess.run(variables.global_variables_initializer())1833 sess.run(["new_model/optimize"], {1834 "new_model/image:0": np.random.random([1, 784]),1835 "new_model/label:0": np.random.randint(1836 10, size=[1, 10])1837 })1838class CheckpointReaderTest(test.TestCase):1839 _WRITE_VERSION = saver_pb2.SaverDef.V11840 def testDebugString(self):1841 # Builds a graph.1842 v0 = variables.Variable(1843 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")1844 v1 = variables.Variable(1845 [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")1846 init_all_op = variables.global_variables_initializer()1847 save = saver_module.Saver(1848 {1849 "v0": v0,1850 "v1": v11851 }, write_version=self._WRITE_VERSION)1852 save_path = os.path.join(self.get_temp_dir(),1853 "ckpt_for_debug_string" + str(self._WRITE_VERSION))1854 with self.test_session() as sess:1855 sess.run(init_all_op)1856 # Saves a checkpoint.1857 save.save(sess, save_path)1858 # Creates a reader.1859 reader = pywrap_tensorflow.NewCheckpointReader(save_path)1860 # Verifies that the tensors exist.1861 self.assertTrue(reader.has_tensor("v0"))1862 self.assertTrue(reader.has_tensor("v1"))1863 debug_string = reader.debug_string()1864 # Verifies that debug string contains the right strings.1865 self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string)1866 self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string)1867 # Verifies get_variable_to_shape_map() returns the correct information.1868 var_map = reader.get_variable_to_shape_map()1869 self.assertEqual([2, 3], var_map["v0"])1870 self.assertEqual([3, 2, 1], var_map["v1"])1871 # Verifies get_tensor() returns the tensor value.1872 v0_tensor = reader.get_tensor("v0")1873 v1_tensor = reader.get_tensor("v1")1874 self.assertAllEqual(v0.eval(), v0_tensor)1875 self.assertAllEqual(v1.eval(), v1_tensor)1876 # Verifies get_tensor() fails for non-existent tensors.1877 with self.assertRaisesRegexp(errors.NotFoundError,1878 "v3 not found 
in checkpoint"):1879 reader.get_tensor("v3")1880 def testNonexistentPath(self):1881 with self.assertRaisesRegexp(errors.NotFoundError,1882 "Unsuccessful TensorSliceReader"):1883 pywrap_tensorflow.NewCheckpointReader("non-existent")1884class CheckpointReaderForV2Test(CheckpointReaderTest):1885 _WRITE_VERSION = saver_pb2.SaverDef.V21886class WriteGraphTest(test.TestCase):1887 def _get_test_dir(self, dirname):1888 test_dir = os.path.join(self.get_temp_dir(), dirname)1889 gfile.MakeDirs(test_dir)1890 return test_dir1891 def testWriteGraph(self):1892 test_dir = self._get_test_dir("write_graph_dir")1893 variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")1894 path = graph_io.write_graph(ops_lib.get_default_graph(),1895 os.path.join(test_dir, "l1"), "graph.pbtxt")1896 truth = os.path.join(test_dir, "l1", "graph.pbtxt")1897 self.assertEqual(path, truth)1898 self.assertTrue(os.path.exists(path))1899 def testRecursiveCreate(self):1900 test_dir = self._get_test_dir("deep_dir")1901 variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")1902 path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),1903 os.path.join(test_dir, "l1", "l2", "l3"),1904 "graph.pbtxt")1905 truth = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt")1906 self.assertEqual(path, truth)1907 self.assertTrue(os.path.exists(path))1908class SaverUtilsTest(test.TestCase):1909 def setUp(self):1910 self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")1911 gfile.MakeDirs(self._base_dir)1912 def tearDown(self):1913 gfile.DeleteRecursively(self._base_dir)1914 def testCheckpointExists(self):1915 for sharded in (False, True):1916 for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):1917 with self.test_session(graph=ops_lib.Graph()) as sess:1918 unused_v = variables.Variable(1.0, name="v")1919 variables.global_variables_initializer().run()1920 saver = saver_module.Saver(sharded=sharded, write_version=version)1921 path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))1922 self.assertFalse(1923 saver_module.checkpoint_exists(path)) # Not saved yet.1924 ckpt_prefix = saver.save(sess, path)1925 self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))1926 ckpt_prefix = saver_module.latest_checkpoint(self._base_dir)1927 self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))1928 def testGetCheckpointMtimes(self):1929 prefixes = []1930 for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):1931 with self.test_session(graph=ops_lib.Graph()) as sess:1932 unused_v = variables.Variable(1.0, name="v")1933 variables.global_variables_initializer().run()1934 saver = saver_module.Saver(write_version=version)1935 prefixes.append(1936 saver.save(sess, os.path.join(self._base_dir, str(version))))1937 mtimes = saver_module.get_checkpoint_mtimes(prefixes)1938 self.assertEqual(2, len(mtimes))1939 self.assertTrue(mtimes[1] >= mtimes[0])1940class ScopedGraphTest(test.TestCase):1941 def _get_test_dir(self, dirname):1942 test_dir = os.path.join(self.get_temp_dir(), dirname)1943 gfile.MakeDirs(test_dir)1944 return test_dir1945 def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):1946 graph = ops_lib.Graph()1947 with graph.as_default():1948 # Creates an inference graph.1949 # Hidden 11950 images = constant_op.constant(1951 1.2, dtypes.float32, shape=[100, 28], name="images")1952 with ops_lib.name_scope("hidden1"):1953 weights1 = variables.Variable(1954 random_ops.truncated_normal(1955 [28, 128], stddev=1.0 / math.sqrt(float(28))),1956 
name="weights")1957 # The use of control_flow_ops.cond here is purely for adding test1958 # coverage the save and restore of control flow context (which doesn't1959 # make any sense here from a machine learning perspective). The typical1960 # biases is a simple Variable without the conditions.1961 biases1 = variables.Variable(1962 control_flow_ops.cond(1963 math_ops.less(random.random(), 0.5),1964 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),1965 name="biases")1966 hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)1967 # Hidden 21968 with ops_lib.name_scope("hidden2"):1969 weights2 = variables.Variable(1970 random_ops.truncated_normal(1971 [128, 32], stddev=1.0 / math.sqrt(float(128))),1972 name="weights")1973 # The use of control_flow_ops.while_loop here is purely for adding test1974 # coverage the save and restore of control flow context (which doesn't1975 # make any sense here from a machine learning perspective). The typical1976 # biases is a simple Variable without the conditions.1977 def loop_cond(it, _):1978 return it < 21979 def loop_body(it, biases2):1980 biases2 += constant_op.constant(0.1, shape=[32])1981 return it + 1, biases21982 _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [1983 constant_op.constant(0), variables.Variable(array_ops.zeros([32]))1984 ])1985 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)1986 # Linear1987 with ops_lib.name_scope("softmax_linear"):1988 weights3 = variables.Variable(1989 random_ops.truncated_normal(1990 [32, 10], stddev=1.0 / math.sqrt(float(32))),1991 name="weights")1992 biases3 = variables.Variable(array_ops.zeros([10]), name="biases")1993 logits = math_ops.matmul(hidden2, weights3) + biases31994 ops_lib.add_to_collection("logits", logits)1995 # Adds user_defined proto in three formats: string, bytes and Any.1996 # Any proto should just pass through.1997 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")1998 ops_lib.add_to_collection("user_defined_string_collection",1999 str(queue_runner))2000 ops_lib.add_to_collection("user_defined_bytes_collection",2001 queue_runner.SerializeToString())2002 any_buf = Any()2003 any_buf.Pack(queue_runner)2004 ops_lib.add_to_collection("user_defined_any_collection", any_buf)2005 _, var_list = meta_graph.export_scoped_meta_graph(2006 filename=os.path.join(test_dir, exported_filename),2007 graph=ops_lib.get_default_graph(),2008 export_scope="hidden1")2009 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))2010 with self.test_session(graph=graph) as sess:2011 sess.run(variables.global_variables_initializer())2012 saver = saver_module.Saver(var_list=var_list, max_to_keep=1)2013 saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)2014 def _testScopedRestore(self, test_dir, exported_filename,2015 new_exported_filename, ckpt_filename):2016 graph = ops_lib.Graph()2017 # Create all the missing inputs.2018 with graph.as_default():2019 new_image = constant_op.constant(2020 1.2, dtypes.float32, shape=[100, 28], name="images")2021 var_list = meta_graph.import_scoped_meta_graph(2022 os.path.join(test_dir, exported_filename),2023 graph=graph,2024 input_map={"$unbound_inputs_images": new_image},2025 import_scope="new_hidden1")2026 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))2027 hidden1 = graph.as_graph_element("new_hidden1/Relu:0")2028 weights1 = graph.as_graph_element("new_hidden1/weights:0")2029 biases1 = graph.as_graph_element("new_hidden1/biases:0")2030 with 
graph.as_default():2031 # Hidden 22032 with ops_lib.name_scope("hidden2"):2033 weights = variables.Variable(2034 random_ops.truncated_normal(2035 [128, 32], stddev=1.0 / math.sqrt(float(128))),2036 name="weights")2037 # The use of control_flow_ops.while_loop here is purely for adding test2038 # coverage the save and restore of control flow context (which doesn't2039 # make any sense here from a machine learning perspective). The typical2040 # biases is a simple Variable without the conditions.2041 def loop_cond(it, _):2042 return it < 22043 def loop_body(it, biases):2044 biases += constant_op.constant(0.1, shape=[32])2045 return it + 1, biases2046 _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [2047 constant_op.constant(0), variables.Variable(array_ops.zeros([32]))2048 ])2049 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)2050 # Linear2051 with ops_lib.name_scope("softmax_linear"):2052 weights = variables.Variable(2053 random_ops.truncated_normal(2054 [32, 10], stddev=1.0 / math.sqrt(float(32))),2055 name="weights")2056 biases = variables.Variable(array_ops.zeros([10]), name="biases")2057 logits = math_ops.matmul(hidden2, weights) + biases2058 ops_lib.add_to_collection("logits", logits)2059 # The rest of the variables.2060 rest_variables = list(2061 set(variables.global_variables()) - set(var_list.keys()))2062 init_rest_op = variables.initialize_variables(rest_variables)2063 with self.test_session(graph=graph) as sess:2064 saver = saver_module.Saver(var_list=var_list, max_to_keep=1)2065 saver.restore(sess, os.path.join(test_dir, ckpt_filename))2066 # Verify that we have restored weights1 and biases1.2067 sess.run([weights1, biases1])2068 # Initialize the rest of the variables and run logits.2069 sess.run(init_rest_op)2070 sess.run(logits)2071 # Verifies that we can save the subgraph under "hidden1" and restore it2072 # into "new_hidden1" in the new graph.2073 def testScopedSaveAndRestore(self):2074 test_dir = self._get_test_dir("scoped_export_import")2075 ckpt_filename = "ckpt"2076 self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)2077 self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",2078 "exported_new_hidden1.pbtxt", ckpt_filename)2079 # Verifies that we can copy the subgraph under "hidden1" and copy it2080 # to different name scope in the same graph or different graph.2081 def testCopyScopedGraph(self):2082 test_dir = self._get_test_dir("scoped_copy")2083 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")2084 graph1 = ops_lib.Graph()2085 with graph1.as_default():2086 with ops_lib.name_scope("hidden1"):2087 images = constant_op.constant(2088 1.0, dtypes.float32, shape=[3, 2], name="images")2089 weights1 = variables.Variable(2090 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")2091 biases1 = variables.Variable([0.1] * 3, name="biases")2092 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")2093 # Run the graph and save scoped checkpoint.2094 with self.test_session(graph=graph1) as sess:2095 sess.run(variables.global_variables_initializer())2096 _, var_list_1 = meta_graph.export_scoped_meta_graph(2097 export_scope="hidden1")2098 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)2099 saver.save(sess, saver0_ckpt, write_state=False)2100 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))2101 # Verifies copy to the same graph with the same name fails.2102 with graph1.as_default():2103 with self.assertRaisesWithPredicateMatch(2104 ValueError, lambda e: "need to be different" in 
str(e)):2105 meta_graph.copy_scoped_meta_graph(2106 from_scope="hidden1", to_scope="hidden1")2107 # Verifies copy to the same graph.2108 with graph1.as_default():2109 var_list_2 = meta_graph.copy_scoped_meta_graph(2110 from_scope="hidden1", to_scope="hidden2")2111 with self.test_session(graph=graph1) as sess:2112 saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)2113 saver1.restore(sess, saver0_ckpt)2114 saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1)2115 saver2.restore(sess, saver0_ckpt)2116 self.assertAllClose(expected, sess.run("hidden1/relu:0"))2117 self.assertAllClose(expected, sess.run("hidden2/relu:0"))2118 # Verifies copy to differen graph.2119 graph2 = ops_lib.Graph()2120 new_var_list_1 = meta_graph.copy_scoped_meta_graph(2121 from_scope="hidden1",2122 to_scope="new_hidden1",2123 from_graph=graph1,2124 to_graph=graph2)2125 with self.test_session(graph=graph2) as sess:2126 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)2127 saver3.restore(sess, saver0_ckpt)2128 self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))2129 def testExportGraphDefWithScope(self):2130 test_dir = self._get_test_dir("export_graph_def")2131 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")2132 graph1 = ops_lib.Graph()2133 with graph1.as_default():2134 with ops_lib.name_scope("hidden1"):2135 images = constant_op.constant(2136 1.0, dtypes.float32, shape=[3, 2], name="images")2137 weights1 = variables.Variable(2138 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")2139 biases1 = variables.Variable([0.1] * 3, name="biases")2140 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")2141 # Run the graph and save scoped checkpoint.2142 with self.test_session(graph=graph1) as sess:2143 sess.run(variables.global_variables_initializer())2144 _, var_list_1 = meta_graph.export_scoped_meta_graph(2145 graph_def=graph1.as_graph_def(), export_scope="hidden1")2146 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)2147 saver.save(sess, saver0_ckpt, write_state=False)2148 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))2149 # Verifies that we can run successfully after restoring.2150 graph2 = ops_lib.Graph()2151 new_var_list_1 = meta_graph.copy_scoped_meta_graph(2152 from_scope="hidden1",2153 to_scope="new_hidden1",2154 from_graph=graph1,2155 to_graph=graph2)2156 with self.test_session(graph=graph2) as sess:2157 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)2158 saver3.restore(sess, saver0_ckpt)2159 self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))2160 def testSerializeSaverWithScope(self):2161 test_dir = self._get_test_dir("export_graph_def")2162 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")2163 saver2_ckpt = os.path.join(test_dir, "saver2.ckpt")2164 graph = ops_lib.Graph()2165 with graph.as_default():2166 with ops_lib.name_scope("hidden1"):2167 variable1 = variables.Variable([1.0], name="variable1")2168 saver1 = saver_module.Saver(var_list=[variable1])2169 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)2170 with ops_lib.name_scope("hidden2"):2171 variable2 = variables.Variable([2.0], name="variable2")2172 saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")2173 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)2174 with self.test_session(graph=graph) as sess:2175 variables.global_variables_initializer().run()2176 saver1.save(sess, saver1_ckpt, write_state=False)2177 saver2.save(sess, saver2_ckpt, write_state=False)2178 graph1 = ops_lib.Graph()2179 
var_dict1 = meta_graph.copy_scoped_meta_graph(2180 from_scope="hidden1",2181 to_scope="new_hidden1",2182 from_graph=graph,2183 to_graph=graph1)2184 self.assertEqual(1, len(var_dict1))2185 saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)2186 self.assertEqual(1, len(saver_list1))2187 with self.test_session(graph=graph1) as sess:2188 saver_list1[0].restore(sess, saver1_ckpt)2189 self.assertEqual(1.0, var_dict1["variable1:0"].eval())2190 graph2 = ops_lib.Graph()2191 var_dict2 = meta_graph.copy_scoped_meta_graph(2192 from_scope="hidden2",2193 to_scope="new_hidden2",2194 from_graph=graph,2195 to_graph=graph2)2196 self.assertEqual(1, len(var_dict2))2197 saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)2198 self.assertEqual(1, len(saver_list2))2199 with self.test_session(graph=graph2) as sess:2200 saver_list2[0].restore(sess, saver2_ckpt)2201 self.assertEqual(2.0, var_dict2["variable2:0"].eval())2202# TODO(b/64763924): Remove after Jan 1st 2018.2203class LenientNamesTest(test.TestCase):2204 def setUp(self):2205 super(LenientNamesTest, self).setUp()2206 os.putenv("TF_SAVER_LENIENT_NAMES", "True")2207 def tearDown(self):2208 os.putenv("TF_SAVER_LENIENT_NAMES", "")2209 super(LenientNamesTest, self).tearDown()2210 def testSaveRestore(self):2211 save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")2212 # Build a graph with 2 parameter nodes, and Save and2213 # Restore nodes for them.2214 v0 = variables.Variable(10.0, name="v0")2215 v1 = variables.Variable(20.0, name="v1")2216 v2 = saver_test_utils.CheckpointedOp(name="v2")2217 v2_init = v2.insert("k1", 30.0)2218 save = saver_module.Saver(2219 {2220 "v0:0": v0,2221 "v1": v1,2222 "v2": v2.saveable2223 }, restore_sequentially=True)2224 init_all_op = [variables.global_variables_initializer(), v2_init]2225 with self.test_session() as sess:2226 sess.run(init_all_op)2227 save.save(sess, save_path)2228 with self.test_session() as sess:2229 v0 = variables.Variable(-1.0, name="v0")2230 v1 = variables.Variable(-1.0, name="v1")2231 v2 = saver_test_utils.CheckpointedOp(name="v2")2232 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})2233 save.restore(sess, save_path)2234 # Check that the parameter nodes have been restored.2235 self.assertEqual(10.0, v0.eval())2236 self.assertEqual(20.0, v1.eval())2237 self.assertEqual(b"k1", v2.keys().eval())2238 self.assertEqual(30.0, v2.values().eval())2239if __name__ == "__main__":...
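The saver_test.py excerpt above exercises tf.train.Saver checkpointing, the checkpoint state file, and MetaGraph export/import. As a quick orientation, here is a minimal sketch of the save/restore cycle those tests cover, written against the public TF 1.x graph-mode API; the ckpt_dir path and the variable values are illustrative, not taken from the tests.

# Minimal sketch of the save/restore cycle the tests above exercise,
# using the public tf.train API (TensorFlow 1.x graph mode).
# "ckpt_dir" is an illustrative path, not taken from the tests.
import os
import tensorflow as tf

ckpt_dir = "/tmp/ckpt_demo"
os.makedirs(ckpt_dir, exist_ok=True)

with tf.Graph().as_default():
    v0 = tf.Variable(10.0, name="v0")
    saver = tf.train.Saver({"v0": v0})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes the checkpoint files and updates the "checkpoint" state file.
        saver.save(sess, os.path.join(ckpt_dir, "model"), global_step=0)

with tf.Graph().as_default():
    v0 = tf.Variable(-1.0, name="v0")
    saver = tf.train.Saver({"v0": v0})
    with tf.Session() as sess:
        # latest_checkpoint reads the state file written by save().
        path = tf.train.latest_checkpoint(ckpt_dir)
        saver.restore(sess, path)
        assert sess.run(v0) == 10.0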
whole_game_work_in_progress.py
Source:whole_game_work_in_progress.py
import pygame
from buildings import Factory
from buildings import Windturbine
from buildings import Cleaning_station

pygame.init()

screen = pygame.display.set_mode((800, 600))
background_color = (101, 56, 24)
screen.fill(background_color)
clock = pygame.time.Clock()
counter = 0
income = 0


def button(screen, position, text):
    font = pygame.font.SysFont("Arial", 50)
    text_render = font.render(text, 1, (0, 0, 0))
    x, y, w, h = text_render.get_rect()
    x, y = position
    pygame.draw.line(screen, (50, 50, 50), (x, y + h), (x + w, y + h), 5)
    pygame.draw.line(screen, (50, 50, 50), (x + w, y + h), [x + w, y], 5)
    pygame.draw.rect(screen, (128, 128, 128), (x, y, w, h))
    return screen.blit(text_render, (x, y))


def button2(screen, position, text, color):
    global text_render, Button_Img
    font = pygame.font.SysFont("Arial", 24)
    if color == "blue":
        text_render = font.render(text, 1, (0, 0, 255))
    elif color == "green":
        text_render = font.render(text, 1, (0, 255, 0))
    elif color == "red":
        text_render = font.render(text, 1, (255, 0, 0))
    elif color == "pink":
        text_render = font.render(text, 1, (255, 182, 193))
    x, y, w, h = text_render.get_rect()
    x, y = position
    pygame.draw.line(screen, (50, 50, 50), (x, y + h), (x + w, y + h), 5)
    pygame.draw.line(screen, (50, 50, 50), (x + w, y + h), [x + w, y], 5)
    pygame.draw.rect(screen, (100, 100, 100), (x, y, w, h))
    return screen.blit(text_render, (x, y))


def buildings_menu():
    global buy_factory, buy_cleaning, buy_windturbine, upgrade
    buy_factory = button2(screen, (500, 100), " BUY FACTORY: 1000 ", "red")
    buy_cleaning = button2(screen, (500, 150), " BUY CL. STATION: 3500 ", "blue")
    buy_windturbine = button2(screen, (500, 200), " BUY WINDMILL: 2000 ", "green")
    upgrade = button2(screen, (500, 300), " UPGRADE: 10000 ", "red")


def buy_menu():
    global buy_land
    buy_land = button2(screen, (500, 250), " BUY: 1500 ", "pink")


def menu():
    global buy_factory, buy_cleaning, buy_windturbine, buy_land, counter, upgrade, clock, income, save_i, save_j, Button_Img
    factory = Factory(1000, 0.03, 0, 1.5)
    windturbine = Windturbine(2000, 0, 0, 0.5)
    cleaning_station = Cleaning_station(3500, 0.055, 0, 0)
    color = (101, 56, 24)
    money = 70000
    counter_factory = 0
    counter_windturbine = 0
    counter_cleaning_station = 0
    counter_factory_upgrade = 0
    counter_cleaning_station_upgrade = 0
    counter_windturbine_upgrade = 0
    polution = 0
    b = [[11, 12, 5, 2], [15, 6, 10, 3], [10, 8, 12, 6], [12, 15, 8, 69], [12, 15, 8, 69]]
    bought = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]

    def display_text():
        font = pygame.font.SysFont("Stencil", 40)
        bar_color = (119, 136, 153)
        pygame.draw.rect(screen, color, pygame.Rect(610, 35, 170, 45))
        pygame.display.flip()
        score_display = font.render(f"MONEY : {money}", 1, (8, 255, 8))
        screen.blit(score_display, (450, 35))
        pygame.draw.rect(screen, bar_color, [80, 30, polution * 3, 30])
        pygame.display.update()

    for i in range(4):  # i indexes the rows of the building grid
        for j in range(4):  # j indexes the columns
            if i == 1:
                b[i][j] = button(screen, ((90 * (j + 1)), 100), " ")  # first value is x, second is y
            elif i == 2:
                b[i][j] = button(screen, ((90 * (j + 1)), 200), " ")
            elif i == 3:
                b[i][j] = button(screen, ((90 * (j + 1)), 300), " ")
            else:
                b[i][j] = button(screen, ((90 * (j + 1)), 400), " ")

    running = 0
    while running == 0:
        mx, my = pygame.mouse.get_pos()
        money += income
        # Update pollution: factories add it, cleaning stations remove it.
        if counter_cleaning_station > 0 or counter_cleaning_station_upgrade > 0 and polution > 0:
            polution = polution + factory.polution * (counter_factory + counter_factory_upgrade) - (cleaning_station.polution * counter_cleaning_station + 3 * cleaning_station.polution * counter_cleaning_station_upgrade)
            polution = round(polution, 2)
        elif counter_cleaning_station == 0 and counter_cleaning_station_upgrade == 0:
            polution = polution + factory.polution * (counter_factory + counter_factory_upgrade)
        elif polution <= 0:
            polution = polution + 0

        display_text()

        if polution > 100000000:
            income = 0
            money = 0
            running = 1

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
            if event.type == pygame.MOUSEBUTTONDOWN:
                # A click on a grid tile opens the buy/build menus and remembers the tile.
                for i in range(4):
                    for j in range(4):
                        if b[i][j].collidepoint(mx, my):
                            buy_menu()
                            buildings_menu()
                            save_i = i
                            save_j = j
                            pass

                # bought[][] encodes the tile state: 0 = unowned, 10 = bought and empty,
                # 11/12/13 = factory/cleaning station/wind turbine, 21/22/23 = upgraded versions.
                if buy_land.collidepoint(mx, my) and money >= 1500 and bought[save_i][save_j] == 0:
                    pygame.draw.rect(screen, (255, 182, 193), b[save_i][save_j])
                    money = money - 1500
                    bought[save_i][save_j] = 10
                if buy_factory.collidepoint(mx, my) and money >= factory.cost and bought[save_i][save_j] == 10:
                    pygame.draw.rect(screen, (255, 0, 0), b[save_i][save_j])
                    money = money - factory.cost
                    counter_factory = counter_factory + 1
                    bought[save_i][save_j] = 11
                    Button_Img = pygame.image.load("buildings/factory.png")
                    if save_j <= 3:
                        screen.blit(Button_Img, (90 * (save_j + 1), 100 * save_i))
                    elif save_i == 3:
                        screen.blit(Button_Img, (90 * (save_j + 1), 100 * save_i))
                if buy_cleaning.collidepoint(mx, my) and money >= cleaning_station.cost and bought[save_i][save_j] == 10:
                    pygame.draw.rect(screen, (173, 216, 230), b[save_i][save_j])
                    money = money - cleaning_station.cost
                    counter_cleaning_station = counter_cleaning_station + 1
                    bought[save_i][save_j] = 12
                    Button_Img = pygame.image.load("buildings/cleaning.png")
                    if save_j <= 2:
                        screen.blit(Button_Img, (90 * (save_j + 1), 100 * save_i))
                    elif save_i == 3:
                        screen.blit(Button_Img, (90 * (save_j + 1), 400))
                if buy_windturbine.collidepoint(mx, my) and money >= windturbine.cost and bought[save_i][save_j] == 10:
                    pygame.draw.rect(screen, (0, 255, 0), b[save_i][save_j])
                    money = money - windturbine.cost
                    counter_windturbine = counter_windturbine + 1
                    bought[save_i][save_j] = 13
                    Button_Img = pygame.image.load("buildings/windturbine.jpeg")
                    if save_j <= 2:
                        screen.blit(Button_Img, (90 * (save_j + 1), 100 * save_i))
                    elif save_j == 3:
                        screen.blit(Button_Img, (90 * (save_j + 1), 400))
                if upgrade.collidepoint(mx, my) and money >= 10000 and bought[save_i][save_j] == 11:
                    pygame.draw.rect(screen, (139, 0, 0), b[save_i][save_j])
                    money = money - 10000
                    counter_factory_upgrade = counter_factory_upgrade + 1
                    counter_factory = counter_factory - 1
                    bought[save_i][save_j] = 21
                    Button_Img = pygame.image.load("buildings/factory_upgrade.png")
                    if save_j <= 2:
                        screen.blit(Button_Img, (90 * (save_j + 1), 100 * save_i))
                    elif save_j == 3:
                        screen.blit(Button_Img, (90 * (save_j + 1), 400))
                if upgrade.collidepoint(mx, my) and money >= 10000 and bought[save_i][save_j] == 12:
                    pygame.draw.rect(screen, (0, 0, 255), b[save_i][save_j])
                    money = money - 10000
                    counter_cleaning_station_upgrade = counter_cleaning_station_upgrade + 1
                    counter_cleaning_station = counter_cleaning_station - 1
                    bought[save_i][save_j] = 22
                if upgrade.collidepoint(mx, my) and money >= 10000 and bought[save_i][save_j] == 13:
                    pygame.draw.rect(screen, (149, 255, 128), b[save_i][save_j])
                    money = money - 10000
                    counter_windturbine_upgrade = counter_windturbine_upgrade + 1
                    counter_windturbine = counter_windturbine - 1
                    bought[save_i][save_j] = 23
                    Button_Img = pygame.image.load("buildings/windturbine_upgrade.jpg")
                    if save_j <= 2:
                        screen.blit(Button_Img, (90 * (save_j + 1), 100 * save_i))
                    elif save_i == 3:
                        screen.blit(Button_Img, (90 * (save_j + 1), 400))
                else:
                    pass

        income = factory.income * counter_factory + windturbine.income * counter_windturbine + 1.5 * counter_factory_upgrade * factory.income + 1.5 * counter_windturbine_upgrade * windturbine.income
        money = round(money, 0)
        clock.tick(60)
...
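The game above builds its clickable grid from the Rect values returned by screen.blit() and hit-tests them with collidepoint() whenever a MOUSEBUTTONDOWN event arrives. A minimal sketch of that button pattern follows; the function and variable names (make_button, button_rect) are illustrative and not part of the game.

# Minimal sketch of the button pattern used above: draw the text, keep the
# Rect returned by blit(), and hit-test it against the click position.
# Function and variable names here are illustrative, not from the game.
import pygame

pygame.init()
screen = pygame.display.set_mode((400, 300))
font = pygame.font.SysFont("Arial", 30)

def make_button(text, pos):
    surface = font.render(text, True, (0, 0, 0))
    rect = surface.get_rect(topleft=pos)
    pygame.draw.rect(screen, (128, 128, 128), rect)
    return screen.blit(surface, pos)  # blit() returns the Rect covering the button

button_rect = make_button(" BUY ", (50, 50))
pygame.display.flip()

running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
        elif event.type == pygame.MOUSEBUTTONDOWN:
            if button_rect.collidepoint(event.pos):
                print("button clicked")
pygame.quit()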
populate_database.py
Source:populate_database.py
...
                 python_module='tools.Setup_workdir.setupWorkDir',
                 hidden=True,
                 tool_folder_name="Setup_workdir"
                 )
module1.save()
module2 = Module(name="Untar",
                 type='0',
                 filter='.*(\.tar)',
                 form=[{"type": "checkbox", "label": "Verbose", "identifier": "verbose", "value": "-v"}],
                 description="Untar all .tar files into the current folder",
                 tool_folder_name="Untar_archive",
                 command="tar xf #file -C #workdir #verbose ",
                 hidden=True
                 )
module2.save()
module3 = Module(name="ClamAV",
                 type='0',
                 form=[{"type": "checkbox", "label": "Only show infected files", "identifier": "only_found", "value": "-i"},
                       {"type": "checkbox", "label": "Remove infected files", "identifier": "remove", "value": "--remove"}],
                 # command='[{"value":"clamscan","type":"text"},{"type":"text","value":"-r"},{"value":"-i","type":"var","name":"only_found"},{"value":"--remove","type":"var","name":"remove"},{"type":"var","name":"workdir"}]',
                 command="clamscan -r #only_found #remove #workdir",
                 tool_folder_name="ClamAV"
                 )
module3.save()
module8 = Module(name="DROID",
                 type='3',
                 form=[],
                 command="/run.sh \"#file\"",
                 filter='.*',
                 tool_folder_name="SMART_DROID",
                 docker_mount_point="/workdir",
                 resultFilter=[{"type": "Containing", "value": "[\\w\\W]*Missmatch: \"false\"[\\w\\W]*"},
                               {"type": "Not containing", "value": "[\\w\\W]*Missmatch: \"true\"[\\w\\W]*"}]
                 )
module8.save()
module11 = Module(name="Unoconv",
                  type='3',
                  form=[],
                  command="UNOPATH=/usr/lib/Libreoffice /usr/bin/python3 /usr/bin/unoconv -f pdf -e SelectPdfVersion=1 \"#file\"",
                  # command="seq -s= 100000|tr -d '[:digit:]'",
                  filter='.*(\\.(doc|docx))$',
                  tool_folder_name="SMART_UNOCONV",
                  docker_mount_point="/workdir",
                  parallell_jobs=6
                  )
module11.save()
module12 = Module(name="Verapdf",
                  type='3',
                  form=[],
                  command="verapdf -f 1a \"#file\"",
                  filter='.*(\\.pdf)',
                  tool_folder_name="SMART_VERAPDF",
                  docker_mount_point="/workdir",
                  resultFilter=[{"type": "Containing", "value": "[\\w\\W]*compliant=\"1\"[\\w\\W]*"}],
                  parallell_jobs=6
                  )
module12.save()

# setup default templates
template1 = Template(name="Default Start", hidden=True)
template1.save()
process2 = Process(order=1,
                   template=template1,
                   module=module2,
                   value={"verbose": True}
                   )
process2.save()
template2 = Template(name="Default Done", hidden=True)
template2.save()
template3 = Template(name="Empty template")
template3.save()
template4 = Template(name="Convert pdf")
template4.save()
process = Process(order=0,
                  template=template4,
                  module=module11)  # unoconv
process.save()
process = Process(order=1,
                  template=template4,
                  module=module12)  # verapdf
process.save()

# create default variables
var = Variable(name="total_number_of_files", data="0")
var.save()
var = Variable(name="total_size", data="0")
var.save()
var = Variable(name="total_number_of_packages", data="0")
var.save()
var = Variable(name="total_number_of_errors", data="0")
var.save()

# system variables
var = Variable(name="work_dir_path", data="/code/workdir")
var.save()
var = Variable(name="packages_path", data="/code/packages")
var.save()
var = Variable(name="tools_path", data="/code/tools")
var.save()
var = Variable(name="premis_file_name", data="log/app_log.xml")
var.save()
var = Variable(name="work_dir_path_host", data="/Users/axenu/Projects/Sydarkivera/APP/workdir")
var.save()
var = Variable(name="premis_template_path", data="/code/templates/premis.json")
var.save()
var = Variable(name="premis_event_template_path", data="/code/templates/premisEvent.json")
var.save()

# default test data, TODO remove in production
# ftype = FileType(name="PDF", errors=3, total=100, size=1203000)
# ftype.save()
# ftype = FileType(name="JPG", errors=33, total=10, size=12033400)
# ftype.save()
# ftype = FileType(name="XML", errors=0, total=43, size=120340)
# ftype.save()
# ftype = FileType(name="XSD", errors=0, total=12, size=120300123)
# ftype.save()
# ftype = FileType(name="sdf", errors=0, total=12, size=120300123)
# ftype.save()
# ftype = FileType(name="Xdh4eSD", errors=0, total=12, size=120300123)
# ftype.save()
# ftype = FileType(name="fgdh", errors=0, total=12, size=120300123)
# ftype.save()
# ftype = FileType(name="wert", errors=0, total=12, size=120300123)
# ftype.save()
# graph = GraphData(date=(datetime.date.today() - datetime.timedelta(days=21)), size=1110000000, count=37954)
# graph.save()
# graph = GraphData(date=(datetime.date.today() - datetime.timedelta(days=14)), size=5834400000, count=9754)
# graph.save()
# graph = GraphData(date=(datetime.date.today() - datetime.timedelta(days=7)), size=2340000000, count=751)
# graph.save()
# graph = GraphData(date=datetime.date.today(), size=300000000, count=3452)
# graph.save()

# create default docker images
# image = DockerImage(name="vera_pdf", mountpoint="/workdir", label="verapdf")
# image.save()
# module12.dockerImage = image
# module12.save()
image = DockerImage(name="axenu/app-worker-droid", mountpoint="/workdir", label="Droid")
image.save()
module8.dockerImage = image
module8.save()
image = DockerImage(name="axenu/app-worker-unoconv", mountpoint="/workdir", label="Unoconv")
image.save()
module11.dockerImage = image
module11.save()
image = DockerImage(name="axenu/app-worker-verapdf", mountpoint="/workdir", label="VeraPDF")
image.save()
module12.dockerImage = image
module12.save()

# create default admin users
User.objects.all().delete()
user = User.objects.create_user('admin', 'simon@axenu.com', 'admin')
user.is_superuser = True
user.is_staff = True
# user.role = 2
user.save()
print('populate_database finished')

# create new package.
# package1 = Package(name="demo paket 1", path="/Users/axenu/Sydarkivera/toolbox/paket/af268c33-5ba8-4af5-9a44-039b10126835.tar", file_name="af268c33-5ba8-4af5-9a44-039b10126835.tar", status=0)
# package1.save()
# create some processes
# process1 = Process(order=1, package=package1, module=module1, value='{}')
# process1.save()
# process2 = Process(order=2, package=package1, module=module2, value='{}')
...
urls.py
Source:urls.py
from django.urls import path, include
from . import views
from . import HodViews, StaffViews, StudentViews

urlpatterns = [
    path('', views.loginPage, name="login"),
    # path('accounts/', include('django.contrib.auth.urls')),
    path('doLogin/', views.doLogin, name="doLogin"),
    path('get_user_details/', views.get_user_details, name="get_user_details"),
    path('logout_user/', views.logout_user, name="logout_user"),

    # URLs for HOD (admin)
    path('admin_home/', HodViews.admin_home, name="admin_home"),
    path('add_staff/', HodViews.add_staff, name="add_staff"),
    path('add_staff_save/', HodViews.add_staff_save, name="add_staff_save"),
    path('manage_staff/', HodViews.manage_staff, name="manage_staff"),
    path('edit_staff/<staff_id>/', HodViews.edit_staff, name="edit_staff"),
    path('edit_staff_save/', HodViews.edit_staff_save, name="edit_staff_save"),
    path('delete_staff/<staff_id>/', HodViews.delete_staff, name="delete_staff"),
    path('add_course/', HodViews.add_course, name="add_course"),
    path('add_course_save/', HodViews.add_course_save, name="add_course_save"),
    path('manage_course/', HodViews.manage_course, name="manage_course"),
    path('edit_course/<course_id>/', HodViews.edit_course, name="edit_course"),
    path('edit_course_save/', HodViews.edit_course_save, name="edit_course_save"),
    path('delete_course/<course_id>/', HodViews.delete_course, name="delete_course"),
    path('manage_session/', HodViews.manage_session, name="manage_session"),
    path('add_session/', HodViews.add_session, name="add_session"),
    path('add_session_save/', HodViews.add_session_save, name="add_session_save"),
    path('edit_session/<session_id>', HodViews.edit_session, name="edit_session"),
    path('edit_session_save/', HodViews.edit_session_save, name="edit_session_save"),
    path('delete_session/<session_id>/', HodViews.delete_session, name="delete_session"),
    path('add_student/', HodViews.add_student, name="add_student"),
    path('add_student_save/', HodViews.add_student_save, name="add_student_save"),
    path('edit_student/<student_id>', HodViews.edit_student, name="edit_student"),
    path('edit_student_save/', HodViews.edit_student_save, name="edit_student_save"),
    path('manage_student/', HodViews.manage_student, name="manage_student"),
    path('delete_student/<student_id>/', HodViews.delete_student, name="delete_student"),
    path('add_subject/', HodViews.add_subject, name="add_subject"),
    path('add_subject_save/', HodViews.add_subject_save, name="add_subject_save"),
    path('manage_subject/', HodViews.manage_subject, name="manage_subject"),
    path('edit_subject/<subject_id>/', HodViews.edit_subject, name="edit_subject"),
    path('edit_subject_save/', HodViews.edit_subject_save, name="edit_subject_save"),
    path('delete_subject/<subject_id>/', HodViews.delete_subject, name="delete_subject"),
    path('check_email_exist/', HodViews.check_email_exist, name="check_email_exist"),
    path('check_username_exist/', HodViews.check_username_exist, name="check_username_exist"),
    path('student_feedback_message/', HodViews.student_feedback_message, name="student_feedback_message"),
    path('student_feedback_message_reply/', HodViews.student_feedback_message_reply, name="student_feedback_message_reply"),
    path('staff_feedback_message/', HodViews.staff_feedback_message, name="staff_feedback_message"),
    path('staff_feedback_message_reply/', HodViews.staff_feedback_message_reply, name="staff_feedback_message_reply"),
    path('student_leave_view/', HodViews.student_leave_view, name="student_leave_view"),
    path('student_leave_approve/<leave_id>/', HodViews.student_leave_approve, name="student_leave_approve"),
    path('student_leave_reject/<leave_id>/', HodViews.student_leave_reject, name="student_leave_reject"),
    path('staff_leave_view/', HodViews.staff_leave_view, name="staff_leave_view"),
    path('staff_leave_approve/<leave_id>/', HodViews.staff_leave_approve, name="staff_leave_approve"),
    path('staff_leave_reject/<leave_id>/', HodViews.staff_leave_reject, name="staff_leave_reject"),
    path('admin_view_attendance/', HodViews.admin_view_attendance, name="admin_view_attendance"),
    path('admin_get_attendance_dates/', HodViews.admin_get_attendance_dates, name="admin_get_attendance_dates"),
    path('admin_get_attendance_student/', HodViews.admin_get_attendance_student, name="admin_get_attendance_student"),
    path('admin_profile/', HodViews.admin_profile, name="admin_profile"),
    path('admin_profile_update/', HodViews.admin_profile_update, name="admin_profile_update"),
    path('add_project/', HodViews.add_project, name="add_project"),
    path('add_project_save/', HodViews.add_project_save, name="add_project_save"),
    path('manage_project/', HodViews.manage_project, name="manage_project"),
    path('edit_project/<project_id>/', HodViews.edit_project, name="edit_project"),
    path('edit_project_save/', HodViews.edit_project_save, name="edit_project_save"),
    path('delete_project/<project_id>/', HodViews.delete_project, name="delete_project"),

    # URLs for Staff
    path('staff_home/', StaffViews.staff_home, name="staff_home"),
    path('staff_take_attendance/', StaffViews.staff_take_attendance, name="staff_take_attendance"),
    path('get_students/', StaffViews.get_students, name="get_students"),
    path('save_attendance_data/', StaffViews.save_attendance_data, name="save_attendance_data"),
    path('staff_update_attendance/', StaffViews.staff_update_attendance, name="staff_update_attendance"),
    path('get_attendance_dates/', StaffViews.get_attendance_dates, name="get_attendance_dates"),
    path('get_attendance_student/', StaffViews.get_attendance_student, name="get_attendance_student"),
    path('update_attendance_data/', StaffViews.update_attendance_data, name="update_attendance_data"),
    path('staff_apply_leave/', StaffViews.staff_apply_leave, name="staff_apply_leave"),
    path('staff_apply_leave_save/', StaffViews.staff_apply_leave_save, name="staff_apply_leave_save"),
    path('staff_feedback/', StaffViews.staff_feedback, name="staff_feedback"),
    path('staff_feedback_save/', StaffViews.staff_feedback_save, name="staff_feedback_save"),
    path('staff_profile/', StaffViews.staff_profile, name="staff_profile"),
    path('staff_profile_update/', StaffViews.staff_profile_update, name="staff_profile_update"),
    path('staff_add_result/', StaffViews.staff_add_result, name="staff_add_result"),
    path('staff_add_result_save/', StaffViews.staff_add_result_save, name="staff_add_result_save"),
    path('staff_project_view/', StaffViews.project_view, name="project_view"),
    path('staff_eachproject_view/<subject_id>/', StaffViews.eachproject_view, name="eachproject_view"),

    # URLs for Student
    path('student_home/', StudentViews.student_home, name="student_home"),
    path('student_view_attendance/', StudentViews.student_view_attendance, name="student_view_attendance"),
    path('student_view_attendance_post/', StudentViews.student_view_attendance_post, name="student_view_attendance_post"),
    path('student_apply_leave/', StudentViews.student_apply_leave, name="student_apply_leave"),
    path('student_apply_leave_save/', StudentViews.student_apply_leave_save, name="student_apply_leave_save"),
    path('student_feedback/', StudentViews.student_feedback, name="student_feedback"),
    path('student_feedback_save/', StudentViews.student_feedback_save, name="student_feedback_save"),
    path('student_profile/', StudentViews.student_profile, name="student_profile"),
    path('student_profile_update/', StudentViews.student_profile_update, name="student_profile_update"),
    path('student_view_result/', StudentViews.student_view_result, name="student_view_result"),
    ...