Best Python code snippet using playwright-python
head_test.py
Source: head_test.py
...93 stirling_approx = z * np.log(z) - z + 0.5 * np.log(2. * np.pi * z)94 lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.)95 return sum(lpl)/len(lpl)96 def testPoissonWithLogits(self):97 head = head_lib.poisson_regression_head()98 labels = ((0.,), (1.,), (1.,))99 logits = ((0.,), (-1.,), (3.,))100 with ops.Graph().as_default(), session.Session():101 model_fn_ops = head.create_model_fn_ops(102 {},103 labels=labels,104 mode=model_fn.ModeKeys.TRAIN,105 train_op_fn=head_lib.no_op_train_fn,106 logits=logits)107 self._assert_output_alternatives(model_fn_ops)108 _assert_summary_tags(self, ["loss"])109 _assert_no_variables(self)110 loss = self._log_poisson_loss(logits, labels)111 _assert_metrics(self, loss, {"loss": loss}, model_fn_ops)112class RegressionHeadTest(test.TestCase):113 def _assert_output_alternatives(self, model_fn_ops):114 self.assertEquals({115 None: constants.ProblemType.LINEAR_REGRESSION116 }, {117 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)118 })119 # TODO(zakaria): test multilabel regression.120 def testRegressionWithLogits(self):121 head = head_lib.regression_head()122 with ops.Graph().as_default(), session.Session():123 model_fn_ops = head.create_model_fn_ops(124 {},125 labels=((0.,), (1.,), (1.,)),126 mode=model_fn.ModeKeys.TRAIN,127 train_op_fn=head_lib.no_op_train_fn,128 logits=((1.,), (1.,), (3.,)))129 self._assert_output_alternatives(model_fn_ops)130 _assert_summary_tags(self, ["loss"])131 _assert_no_variables(self)132 _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)133 def testRegressionWithLogitFn(self):134 head = head_lib.regression_head(link_fn=math_ops.square)135 def _assert_preditions(test_case, expected_predictions, model_fn_ops):136 variables.initialize_local_variables().run()137 test_case.assertAllClose(expected_predictions,138 model_fn_ops.predictions["scores"].eval())139 with ops.Graph().as_default(), session.Session():140 model_fn_ops = head.create_model_fn_ops(141 {},142 labels=((0.,), (1.,), (1.,)),143 mode=model_fn.ModeKeys.TRAIN,144 train_op_fn=head_lib.no_op_train_fn,145 logits=((1.,), (1.,), (3.,)))146 self._assert_output_alternatives(model_fn_ops)147 _assert_summary_tags(self, ["loss"])148 _assert_no_variables(self)149 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)150 _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops)151 def testRegressionWithInvalidLogits(self):152 head = head_lib.regression_head()153 with ops.Graph().as_default(), session.Session():154 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):155 head.create_model_fn_ops(156 {},157 labels=((0.,), (1.,), (1.,)),158 mode=model_fn.ModeKeys.TRAIN,159 train_op_fn=head_lib.no_op_train_fn,160 logits=((1., 1.), (1., 1.), (3., 1.)))161 def testRegressionWithLogitsInput(self):162 head = head_lib.regression_head()163 with ops.Graph().as_default(), session.Session():164 model_fn_ops = head.create_model_fn_ops(165 {},166 labels=((0.,), (1.,), (1.,)),167 mode=model_fn.ModeKeys.TRAIN,168 train_op_fn=head_lib.no_op_train_fn,169 logits_input=((0., 0.), (0., 0.), (0., 0.)))170 self._assert_output_alternatives(model_fn_ops)171 w = ("regression_head/logits/weights:0",172 "regression_head/logits/biases:0")173 _assert_variables(174 self, expected_global=w, expected_model=w, expected_trainable=w)175 variables.global_variables_initializer().run()176 _assert_summary_tags(self, ["loss"])177 _assert_metrics(self, 2. / 3, {"loss": 2. 
/ 3}, model_fn_ops)178 def testRegressionWithLogitsAndLogitsInput(self):179 head = head_lib.regression_head()180 with ops.Graph().as_default(), session.Session():181 with self.assertRaisesRegexp(182 ValueError, "Both logits and logits_input supplied"):183 head.create_model_fn_ops(184 {},185 labels=((0.,), (1.,), (1.,)),186 mode=model_fn.ModeKeys.TRAIN,187 train_op_fn=head_lib.no_op_train_fn,188 logits_input=((0., 0.), (0., 0.), (0., 0.)),189 logits=((1.,), (1.,), (3.,)))190 def testRegressionEvalMode(self):191 head = head_lib.regression_head()192 with ops.Graph().as_default(), session.Session():193 model_fn_ops = head.create_model_fn_ops(194 {},195 labels=((1.,), (1.,), (3.,)),196 mode=model_fn.ModeKeys.EVAL,197 train_op_fn=head_lib.no_op_train_fn,198 logits=((0.,), (1.,), (1.,)))199 self._assert_output_alternatives(model_fn_ops)200 self.assertIsNone(model_fn_ops.train_op)201 _assert_no_variables(self)202 _assert_summary_tags(self, ["loss"])203 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)204 def testRegressionWithLabelName(self):205 label_name = "my_label"206 head = head_lib.regression_head(label_name=label_name)207 with ops.Graph().as_default(), session.Session():208 model_fn_ops = head.create_model_fn_ops(209 {},210 labels={label_name: ((0.,), (1.,), (1.,))},211 mode=model_fn.ModeKeys.TRAIN,212 train_op_fn=head_lib.no_op_train_fn,213 logits=((1.,), (1.,), (3.,)))214 self._assert_output_alternatives(model_fn_ops)215 _assert_no_variables(self)216 _assert_summary_tags(self, ["loss"])217 _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops)218 def testRegressionWithScalarWeights(self):219 head = head_lib.regression_head(weight_column_name="label_weight")220 with ops.Graph().as_default(), session.Session():221 weights = 2.222 labels = ((0.,), (1.,), (1.,))223 model_fn_ops = head.create_model_fn_ops(224 features={"label_weight": weights},225 labels=labels,226 mode=model_fn.ModeKeys.TRAIN,227 train_op_fn=head_lib.no_op_train_fn,228 logits=((1.,), (1.,), (3.,)))229 self._assert_output_alternatives(model_fn_ops)230 _assert_no_variables(self)231 _assert_summary_tags(self, ["loss"])232 _assert_metrics(self, (weights * 5.) / len(labels), {233 "loss": (weights * 5.) / (weights * len(labels))234 }, model_fn_ops)235 def testRegressionWith1DWeights(self):236 head = head_lib.regression_head(weight_column_name="label_weight")237 with ops.Graph().as_default(), session.Session():238 weights = (2., 5., 0.)239 labels = ((0.,), (1.,), (1.,))240 model_fn_ops = head.create_model_fn_ops(241 features={"label_weight": weights},242 labels=labels,243 mode=model_fn.ModeKeys.TRAIN,244 train_op_fn=head_lib.no_op_train_fn,245 logits=((1.,), (1.,), (3.,)))246 self._assert_output_alternatives(model_fn_ops)247 _assert_no_variables(self)248 _assert_summary_tags(self, ["loss"])249 _assert_metrics(self, 2. / len(labels), {"loss": 2. / np.sum(weights)},250 model_fn_ops)251 def testRegressionWith2DWeights(self):252 head = head_lib.regression_head(weight_column_name="label_weight")253 with ops.Graph().as_default(), session.Session():254 weights = ((2.,), (5.,), (0.,))255 labels = ((0.,), (1.,), (1.,))256 model_fn_ops = head.create_model_fn_ops(257 features={"label_weight": weights},258 labels=labels,259 mode=model_fn.ModeKeys.TRAIN,260 train_op_fn=head_lib.no_op_train_fn,261 logits=((1.,), (1.,), (3.,)))262 self._assert_output_alternatives(model_fn_ops)263 _assert_no_variables(self)264 _assert_summary_tags(self, ["loss"])265 _assert_metrics(self, 2. / len(labels), {"loss": 2. 
/ np.sum(weights)},266 model_fn_ops)267 def testRegressionWithCenteredBias(self):268 head = head_lib.regression_head(enable_centered_bias=True)269 with ops.Graph().as_default(), session.Session():270 model_fn_ops = head.create_model_fn_ops(271 {},272 labels=((0.,), (1.,), (1.,)),273 mode=model_fn.ModeKeys.TRAIN,274 train_op_fn=head_lib.no_op_train_fn,275 logits=((1.,), (1.,), (3.,)))276 self._assert_output_alternatives(model_fn_ops)277 _assert_variables(278 self,279 expected_global=(280 "regression_head/centered_bias_weight:0",281 "regression_head/regression_head/centered_bias_weight/Adagrad:0",282 ),283 expected_trainable=("regression_head/centered_bias_weight:0",))284 variables.global_variables_initializer().run()285 _assert_summary_tags(self, [286 "loss",287 "regression_head/centered_bias/bias_0"288 ])289 _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)290 def testRegressionErrorInSparseTensorLabels(self):291 head = head_lib.regression_head()292 with ops.Graph().as_default():293 labels = sparse_tensor.SparseTensorValue(294 indices=((0, 0), (1, 0), (2, 0)),295 values=(0., 1., 1.),296 dense_shape=(3, 1))297 with self.assertRaisesRegexp(ValueError,298 "SparseTensor is not supported"):299 head.create_model_fn_ops(300 {},301 labels=labels,302 mode=model_fn.ModeKeys.TRAIN,303 train_op_fn=head_lib.no_op_train_fn,304 logits=((1.,), (1.,), (3.,)))305class MultiLabelHeadTest(test.TestCase):306 def _assert_output_alternatives(self, model_fn_ops):307 self.assertEquals({308 None: constants.ProblemType.CLASSIFICATION309 }, {310 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)311 })312 def setUp(self):313 self._logits = ((1., 0., 0.),)314 self._labels = ((0, 0, 1),)315 def _expected_eval_metrics(self, expected_loss):316 return {317 "accuracy": 1. / 3,318 "loss": expected_loss,319 "auc": 1. 
/ 4,320 "auc/class0": 1.,321 "auc/class1": 1.,322 "auc/class2": 0.,323 "auc_precision_recall": 0.166667,324 "auc_precision_recall/class0": 0,325 "auc_precision_recall/class1": 0.,326 "auc_precision_recall/class2": 1.,327 "labels/actual_label_mean/class0": self._labels[0][0],328 "labels/actual_label_mean/class1": self._labels[0][1],329 "labels/actual_label_mean/class2": self._labels[0][2],330 "labels/logits_mean/class0": self._logits[0][0],331 "labels/logits_mean/class1": self._logits[0][1],332 "labels/logits_mean/class2": self._logits[0][2],333 "labels/prediction_mean/class0": self._logits[0][0],334 "labels/prediction_mean/class1": self._logits[0][1],335 "labels/prediction_mean/class2": self._logits[0][2],336 "labels/probability_mean/class0": _sigmoid(self._logits[0][0]),337 "labels/probability_mean/class1": _sigmoid(self._logits[0][1]),338 "labels/probability_mean/class2": _sigmoid(self._logits[0][2]),339 }340 def testMultiLabelWithLogits(self):341 n_classes = 3342 head = head_lib.multi_label_head(343 n_classes=n_classes, metric_class_ids=range(n_classes))344 with ops.Graph().as_default(), session.Session():345 model_fn_ops = head.create_model_fn_ops(346 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,347 logits=self._logits)348 self._assert_output_alternatives(model_fn_ops)349 _assert_no_variables(self)350 _assert_summary_tags(self, ["loss"])351 expected_loss = .89985204352 _assert_metrics(self, expected_loss,353 self._expected_eval_metrics(expected_loss), model_fn_ops)354 def testMultiLabelTwoClasses(self):355 n_classes = 2356 labels = ((0, 1),)357 logits = ((1., 0.),)358 head = head_lib.multi_label_head(359 n_classes=n_classes, metric_class_ids=range(n_classes))360 with ops.Graph().as_default(), session.Session():361 model_fn_ops = head.create_model_fn_ops(362 {}, model_fn.ModeKeys.TRAIN, labels=labels,363 train_op_fn=head_lib.no_op_train_fn, logits=logits)364 self._assert_output_alternatives(model_fn_ops)365 _assert_no_variables(self)366 
_assert_summary_tags(self, ["loss"])367 expected_loss = 1.00320443368 _assert_metrics(self, expected_loss, {369 "accuracy": 0.,370 "auc": 0.,371 "loss": expected_loss,372 "auc/class0": 1.,373 "auc/class1": 0.,374 "labels/actual_label_mean/class0": labels[0][0],375 "labels/actual_label_mean/class1": labels[0][1],376 "labels/logits_mean/class0": logits[0][0],377 "labels/logits_mean/class1": logits[0][1],378 "labels/prediction_mean/class0": logits[0][0],379 "labels/prediction_mean/class1": logits[0][1],380 "labels/probability_mean/class0": _sigmoid(logits[0][0]),381 "labels/probability_mean/class1": _sigmoid(logits[0][1]),382 }, model_fn_ops)383 def testMultiLabelWithInvalidLogits(self):384 head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)385 with ops.Graph().as_default(), session.Session():386 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):387 head.create_model_fn_ops(388 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,389 logits=self._logits)390 def testMultiLabelWithLogitsInput(self):391 n_classes = 3392 head = head_lib.multi_label_head(393 n_classes=n_classes, metric_class_ids=range(n_classes))394 with ops.Graph().as_default(), session.Session():395 model_fn_ops = head.create_model_fn_ops(396 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,397 logits_input=((0., 0.),))398 self._assert_output_alternatives(model_fn_ops)399 w = ("multi_label_head/logits/weights:0",400 "multi_label_head/logits/biases:0")401 _assert_variables(402 self, expected_global=w, expected_model=w, expected_trainable=w)403 variables.global_variables_initializer().run()404 _assert_summary_tags(self, ["loss"])405 expected_loss = .69314718406 _assert_metrics(self, expected_loss, {407 "accuracy": 2. / 3,408 "auc": 2. 
/ 4,409 "loss": expected_loss,410 "auc/class0": 1.,411 "auc/class1": 1.,412 "auc/class2": 0.,413 "labels/actual_label_mean/class0": self._labels[0][0],414 "labels/actual_label_mean/class1": self._labels[0][1],415 "labels/actual_label_mean/class2": self._labels[0][2],416 "labels/logits_mean/class0": 0.,417 "labels/logits_mean/class1": 0.,418 "labels/logits_mean/class2": 0.,419 "labels/prediction_mean/class0": 0.,420 "labels/prediction_mean/class1": 0.,421 "labels/prediction_mean/class2": 0.,422 "labels/probability_mean/class0": .5,423 "labels/probability_mean/class1": .5,424 "labels/probability_mean/class2": .5,425 }, model_fn_ops)426 def testMultiLabelWithLogitsAndLogitsInput(self):427 n_classes = 3428 head = head_lib.multi_label_head(429 n_classes=n_classes, metric_class_ids=range(n_classes))430 with ops.Graph().as_default(), session.Session():431 with self.assertRaisesRegexp(432 ValueError, "Both logits and logits_input supplied"):433 head.create_model_fn_ops(434 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,435 logits_input=((0., 0.),), logits=self._logits)436 def testMultiLabelEval(self):437 n_classes = 3438 head = head_lib.multi_label_head(439 n_classes=n_classes, metric_class_ids=range(n_classes))440 with ops.Graph().as_default(), session.Session():441 model_fn_ops = head.create_model_fn_ops(442 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,443 logits=self._logits)444 self._assert_output_alternatives(model_fn_ops)445 self.assertIsNone(model_fn_ops.train_op)446 _assert_no_variables(self)447 _assert_summary_tags(self, ["loss"])448 expected_loss = .89985204449 _assert_metrics(self, expected_loss,450 self._expected_eval_metrics(expected_loss), model_fn_ops)451 def testMultiClassEvalWithLargeLogits(self):452 n_classes = 3453 head = head_lib.multi_label_head(454 n_classes=n_classes, metric_class_ids=range(n_classes))455 logits = ((2., 0., -1),)456 with ops.Graph().as_default(), session.Session():457 # logloss: z:label, 
x:logit458 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))459 model_fn_ops = head.create_model_fn_ops(460 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,461 logits=logits)462 self._assert_output_alternatives(model_fn_ops)463 self.assertIsNone(model_fn_ops.train_op)464 _assert_no_variables(self)465 _assert_summary_tags(self, ["loss"])466 expected_loss = 1.377779467 expected_eval_metrics = {468 "accuracy": 1. / 3,469 "auc": 9.99999e-07,470 "loss": expected_loss,471 "auc/class0": 1.,472 "auc/class1": 1.,473 "auc/class2": 0.,474 "labels/actual_label_mean/class0": 0. / 1,475 "labels/actual_label_mean/class1": 0. / 1,476 "labels/actual_label_mean/class2": 1. / 1,477 "labels/logits_mean/class0": logits[0][0],478 "labels/logits_mean/class1": logits[0][1],479 "labels/logits_mean/class2": logits[0][2],480 "labels/prediction_mean/class0": 1,481 "labels/prediction_mean/class1": 0,482 "labels/prediction_mean/class2": 0,483 "labels/probability_mean/class0": _sigmoid(logits[0][0]),484 "labels/probability_mean/class1": _sigmoid(logits[0][1]),485 "labels/probability_mean/class2": _sigmoid(logits[0][2]),486 }487 _assert_metrics(self, expected_loss,488 expected_eval_metrics, model_fn_ops)489 def testMultiLabelInfer(self):490 n_classes = 3491 head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")492 with ops.Graph().as_default(), session.Session():493 model_fn_ops = head.create_model_fn_ops(494 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,495 logits=((1., 0., 0.), (0., 0., 1)))496 self.assertIsNone(model_fn_ops.train_op)497 _assert_no_variables(self)498 with session.Session():499 self.assertListEqual(500 [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])501 self.assertItemsEqual(502 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))503 self.assertEqual(504 constants.ProblemType.CLASSIFICATION,505 model_fn_ops.output_alternatives["head_name"][0])506 predictions_for_serving = (507 
model_fn_ops.output_alternatives["head_name"][1])508 self.assertIn("classes", six.iterkeys(predictions_for_serving))509 self.assertAllEqual(510 [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],511 predictions_for_serving["classes"].eval())512 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))513 self.assertAllClose(514 [[0.731059, 0.5, 0.5],515 [0.5, 0.5, 0.731059,]],516 predictions_for_serving["probabilities"].eval())517 def testMultiLabelWithLabelName(self):518 n_classes = 3519 label_name = "my_label"520 head = head_lib.multi_label_head(521 n_classes=n_classes,522 label_name=label_name,523 metric_class_ids=range(n_classes))524 with ops.Graph().as_default(), session.Session():525 model_fn_ops = head.create_model_fn_ops(526 {}, model_fn.ModeKeys.TRAIN, {label_name: self._labels},527 head_lib.no_op_train_fn, logits=self._logits)528 self._assert_output_alternatives(model_fn_ops)529 _assert_no_variables(self)530 _assert_summary_tags(self, ["loss"])531 expected_loss = .89985204532 _assert_metrics(self, expected_loss,533 self._expected_eval_metrics(expected_loss), model_fn_ops)534 def testMultiLabelWithScalarWeight(self):535 n_classes = 3536 head = head_lib.multi_label_head(537 n_classes=n_classes,538 weight_column_name="label_weight",539 metric_class_ids=range(n_classes))540 with ops.Graph().as_default(), session.Session():541 model_fn_ops = head.create_model_fn_ops(542 features={"label_weight": .1},543 labels=self._labels,544 mode=model_fn.ModeKeys.TRAIN,545 train_op_fn=head_lib.no_op_train_fn,546 logits=self._logits)547 self._assert_output_alternatives(model_fn_ops)548 _assert_no_variables(self)549 _assert_summary_tags(self, ["loss"])550 _assert_metrics(self, .089985214,551 self._expected_eval_metrics(.89985214), model_fn_ops)552 def testMultiLabelWith1DWeight(self):553 n_classes = 3554 head = head_lib.multi_label_head(555 n_classes=n_classes,556 weight_column_name="label_weight",557 metric_class_ids=range(n_classes))558 with ops.Graph().as_default(), 
session.Session():559 with self.assertRaisesRegexp(560 ValueError, "weights can not be broadcast to values"):561 head.create_model_fn_ops(562 features={"label_weight": (.1, .1, .1)},563 labels=self._labels,564 mode=model_fn.ModeKeys.TRAIN,565 train_op_fn=head_lib.no_op_train_fn,566 logits=self._logits)567 def testMultiLabelWith2DWeight(self):568 n_classes = 3569 head = head_lib.multi_label_head(570 n_classes=n_classes,571 weight_column_name="label_weight",572 metric_class_ids=range(n_classes))573 with ops.Graph().as_default(), session.Session():574 model_fn_ops = head.create_model_fn_ops(575 features={"label_weight": ((.1, .1, .1),)},576 labels=self._labels,577 mode=model_fn.ModeKeys.TRAIN,578 train_op_fn=head_lib.no_op_train_fn,579 logits=self._logits)580 self._assert_output_alternatives(model_fn_ops)581 _assert_no_variables(self)582 _assert_summary_tags(self, ["loss"])583 _assert_metrics(self, .089985214,584 self._expected_eval_metrics(.89985214), model_fn_ops)585 def testMultiLabelWithCustomLoss(self):586 n_classes = 3587 head = head_lib.multi_label_head(588 n_classes=n_classes,589 weight_column_name="label_weight",590 metric_class_ids=range(n_classes),591 loss_fn=_sigmoid_cross_entropy)592 with ops.Graph().as_default(), session.Session():593 model_fn_ops = head.create_model_fn_ops(594 features={"label_weight": .1},595 labels=self._labels,596 mode=model_fn.ModeKeys.TRAIN,597 train_op_fn=head_lib.no_op_train_fn,598 logits=self._logits)599 self._assert_output_alternatives(model_fn_ops)600 _assert_no_variables(self)601 _assert_summary_tags(self, ["loss"])602 expected_loss = .089985214603 _assert_metrics(self, expected_loss,604 self._expected_eval_metrics(expected_loss), model_fn_ops)605 def testMultiLabelWithCenteredBias(self):606 n_classes = 3607 head = head_lib.multi_label_head(608 n_classes=n_classes,609 enable_centered_bias=True,610 metric_class_ids=range(n_classes))611 with ops.Graph().as_default(), session.Session():612 model_fn_ops = 
head.create_model_fn_ops(613 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,614 logits=self._logits)615 self._assert_output_alternatives(model_fn_ops)616 _assert_variables(617 self,618 expected_global=(619 "multi_label_head/centered_bias_weight:0",620 ("multi_label_head/multi_label_head/centered_bias_weight/"621 "Adagrad:0"),),622 expected_trainable=("multi_label_head/centered_bias_weight:0",))623 variables.global_variables_initializer().run()624 _assert_summary_tags(self, (625 "loss",626 "multi_label_head/centered_bias/bias_0",627 "multi_label_head/centered_bias/bias_1",628 "multi_label_head/centered_bias/bias_2"629 ))630 expected_loss = .89985204631 _assert_metrics(self, expected_loss,632 self._expected_eval_metrics(expected_loss), model_fn_ops)633 def testMultiLabelSparseTensorLabels(self):634 n_classes = 3635 head = head_lib.multi_label_head(636 n_classes=n_classes, metric_class_ids=range(n_classes))637 with ops.Graph().as_default(), session.Session():638 labels = sparse_tensor.SparseTensorValue(639 indices=((0, 0),),640 values=(2,),641 dense_shape=(1, 1))642 model_fn_ops = head.create_model_fn_ops(643 features={},644 mode=model_fn.ModeKeys.TRAIN,645 labels=labels,646 train_op_fn=head_lib.no_op_train_fn,647 logits=self._logits)648 _assert_no_variables(self)649 _assert_summary_tags(self, ["loss"])650 expected_loss = .89985204651 _assert_metrics(self, expected_loss,652 self._expected_eval_metrics(expected_loss), model_fn_ops)653 def testMultiLabelSparseTensorLabelsTooFewClasses(self):654 n_classes = 3655 head = head_lib.multi_label_head(656 n_classes=n_classes, metric_class_ids=range(n_classes))657 # Set _logits_dimension (n_classes) to a lower value; if it's set to 1658 # upfront, the class throws an error during initialization.659 head._logits_dimension = 1660 with ops.Graph().as_default(), session.Session():661 labels = sparse_tensor.SparseTensorValue(662 indices=((0, 0),),663 values=(2,),664 dense_shape=(1, 1))665 with 
self.assertRaisesRegexp(ValueError,666 "Must set num_classes >= 2 when passing"):667 head.create_model_fn_ops(668 features={},669 labels=labels,670 mode=model_fn.ModeKeys.TRAIN,671 train_op_fn=head_lib.no_op_train_fn,672 logits=[0.])673class BinaryClassificationHeadTest(test.TestCase):674 def _assert_output_alternatives(self, model_fn_ops):675 self.assertEquals({676 None: constants.ProblemType.LOGISTIC_REGRESSION677 }, {678 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)679 })680 def setUp(self):681 self._logits = ((1.,), (1.,))682 self._labels = ((1.,), (0.,))683 def _expected_eval_metrics(self, expected_loss):684 label_mean = np.mean(self._labels)685 return {686 "accuracy": 1. / 2,687 "accuracy/baseline_label_mean": label_mean,688 "accuracy/threshold_0.500000_mean": 1. / 2,689 "auc": 1. / 2,690 "auc_precision_recall": 0.749999,691 "labels/actual_label_mean": label_mean,692 "labels/prediction_mean": .731059, # softmax693 "loss": expected_loss,694 "precision/positive_threshold_0.500000_mean": 1. / 2,695 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,696 }697 def testBinaryClassificationWithLogits(self):698 n_classes = 2699 head = head_lib.multi_class_head(n_classes=n_classes)700 with ops.Graph().as_default(), session.Session():701 # logloss: z:label, x:logit702 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))703 model_fn_ops = head.create_model_fn_ops(704 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,705 logits=self._logits)706 self._assert_output_alternatives(model_fn_ops)707 _assert_no_variables(self)708 _assert_summary_tags(self, ["loss"])709 expected_loss = .81326175710 _assert_metrics(self, expected_loss,711 self._expected_eval_metrics(expected_loss), model_fn_ops)712 def testBinaryClassificationWithInvalidLogits(self):713 head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)714 with ops.Graph().as_default(), session.Session():715 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):716 head.create_model_fn_ops(717 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,718 logits=self._logits)719 def testBinaryClassificationWithLogitsInput(self):720 n_classes = 2721 head = head_lib.multi_class_head(n_classes=n_classes)722 with ops.Graph().as_default(), session.Session():723 # logloss: z:label, x:logit724 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))725 model_fn_ops = head.create_model_fn_ops(726 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,727 logits_input=((0., 0.), (0., 0.)))728 self._assert_output_alternatives(model_fn_ops)729 w = ("binary_logistic_head/logits/weights:0",730 "binary_logistic_head/logits/biases:0")731 _assert_variables(732 self, expected_global=w, expected_model=w, expected_trainable=w)733 variables.global_variables_initializer().run()734 _assert_summary_tags(self, ["loss"])735 expected_loss = .69314718736 label_mean = np.mean(self._labels)737 _assert_metrics(self, expected_loss, {738 "accuracy": 1. 
/ 2,739 "accuracy/baseline_label_mean": label_mean,740 "accuracy/threshold_0.500000_mean": 1. / 2,741 "auc": 1. / 2,742 "labels/actual_label_mean": label_mean,743 "labels/prediction_mean": .5, # softmax744 "loss": expected_loss,745 "precision/positive_threshold_0.500000_mean": 0. / 2,746 "recall/positive_threshold_0.500000_mean": 0. / 1,747 }, model_fn_ops)748 def testBinaryClassificationWithLogitsAndLogitsInput(self):749 head = head_lib.multi_class_head(n_classes=2)750 with ops.Graph().as_default(), session.Session():751 with self.assertRaisesRegexp(752 ValueError, "Both logits and logits_input supplied"):753 head.create_model_fn_ops(754 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,755 logits_input=((0., 0.), (0., 0.)), logits=self._logits)756 def testBinaryClassificationEval(self):757 n_classes = 2758 head = head_lib.multi_class_head(n_classes=n_classes)759 with ops.Graph().as_default(), session.Session():760 # logloss: z:label, x:logit761 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))762 model_fn_ops = head.create_model_fn_ops(763 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,764 logits=self._logits)765 self._assert_output_alternatives(model_fn_ops)766 self.assertIsNone(model_fn_ops.train_op)767 _assert_no_variables(self)768 _assert_summary_tags(self, ["loss"])769 expected_loss = .81326175770 _assert_metrics(self, expected_loss,771 self._expected_eval_metrics(expected_loss), model_fn_ops)772 def testBinaryClassificationInfer(self):773 n_classes = 2774 head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")775 with ops.Graph().as_default(), session.Session():776 # logloss: z:label, x:logit777 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))778 model_fn_ops = head.create_model_fn_ops(779 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,780 logits=self._logits)781 self.assertIsNone(model_fn_ops.train_op)782 _assert_no_variables(self)783 with session.Session():784 
self.assertListEqual(785 [1, 1], list(model_fn_ops.predictions["classes"].eval()))786 self.assertItemsEqual(787 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))788 self.assertEqual(789 constants.ProblemType.LOGISTIC_REGRESSION,790 model_fn_ops.output_alternatives["head_name"][0])791 predictions_for_serving = (792 model_fn_ops.output_alternatives["head_name"][1])793 self.assertIn("classes", six.iterkeys(predictions_for_serving))794 predicted_classes = predictions_for_serving["classes"].eval().tolist()795 self.assertListEqual(796 [b"0", b"1"], predicted_classes[0])797 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))798 def testBinaryClassificationInferMode_withWeightColumn(self):799 n_classes = 2800 head = head_lib.multi_class_head(n_classes=n_classes,801 weight_column_name="label_weight")802 with ops.Graph().as_default(), session.Session():803 # logloss: z:label, x:logit804 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))805 model_fn_ops = head.create_model_fn_ops(806 # This is what is being tested, features should not have weight for807 # inference.808 {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,809 logits=self._logits)810 self._assert_output_alternatives(model_fn_ops)811 self.assertIsNone(model_fn_ops.train_op)812 _assert_no_variables(self)813 def testErrorInSparseTensorLabels(self):814 n_classes = 2815 head = head_lib.multi_class_head(n_classes=n_classes)816 with ops.Graph().as_default():817 labels = sparse_tensor.SparseTensorValue(818 indices=((0, 0), (1, 0), (2, 0)),819 values=(0, 1, 1),820 dense_shape=(3, 1))821 with self.assertRaisesRegexp(ValueError,822 "SparseTensor is not supported"):823 head.create_model_fn_ops(824 {},825 model_fn.ModeKeys.TRAIN,826 labels,827 head_lib.no_op_train_fn,828 logits=((1.,), (1.,), (3.,)))829 def testBinaryClassificationWithLabelName(self):830 label_name = "my_label"831 head = head_lib.multi_class_head(n_classes=2, label_name=label_name)832 with 
ops.Graph().as_default(), session.Session():833 # logloss: z:label, x:logit834 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))835 model_fn_ops = head.create_model_fn_ops(836 {},837 labels={label_name: self._labels},838 mode=model_fn.ModeKeys.TRAIN,839 train_op_fn=head_lib.no_op_train_fn,840 logits=self._logits)841 self._assert_output_alternatives(model_fn_ops)842 _assert_no_variables(self)843 _assert_summary_tags(self, ["loss"])844 expected_loss = .81326175845 _assert_metrics(self, expected_loss,846 self._expected_eval_metrics(expected_loss), model_fn_ops)847 def testBinaryClassificationWith1DWeights(self):848 n_classes = 2849 head = head_lib.multi_class_head(850 n_classes=n_classes, weight_column_name="label_weight")851 with ops.Graph().as_default(), session.Session():852 weights = (1., 0.)853 # logloss: z:label, x:logit854 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))855 model_fn_ops = head.create_model_fn_ops(856 features={"label_weight": weights},857 labels=self._labels,858 mode=model_fn.ModeKeys.TRAIN,859 train_op_fn=head_lib.no_op_train_fn,860 logits=self._logits)861 self._assert_output_alternatives(model_fn_ops)862 _assert_no_variables(self)863 _assert_summary_tags(self, ["loss"])864 expected_total_loss = .31326166865 _assert_metrics(866 self,867 expected_total_loss / len(weights),868 {869 "accuracy": 1. / 1,870 "accuracy/baseline_label_mean": 1. / 1,871 "accuracy/threshold_0.500000_mean": 1. / 1,872 "auc": 0. / 1,873 "labels/actual_label_mean": 1. / 1,874 "labels/prediction_mean": .731059, # softmax875 # eval loss is weighted loss divided by sum of weights.876 "loss": expected_total_loss,877 "precision/positive_threshold_0.500000_mean": 1. / 1,878 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,879 },880 model_fn_ops)881 def testBinaryClassificationWith2DWeights(self):882 n_classes = 2883 head = head_lib.multi_class_head(884 n_classes=n_classes, weight_column_name="label_weight")885 with ops.Graph().as_default(), session.Session():886 weights = ((1.,), (0.,))887 # logloss: z:label, x:logit888 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))889 model_fn_ops = head.create_model_fn_ops(890 features={"label_weight": weights},891 labels=self._labels,892 mode=model_fn.ModeKeys.TRAIN,893 train_op_fn=head_lib.no_op_train_fn,894 logits=self._logits)895 self._assert_output_alternatives(model_fn_ops)896 _assert_no_variables(self)897 _assert_summary_tags(self, ["loss"])898 expected_total_loss = .31326166899 _assert_metrics(900 self,901 expected_total_loss / len(weights),902 {903 "accuracy": 1. / 1,904 "accuracy/baseline_label_mean": 1. / 1,905 "accuracy/threshold_0.500000_mean": 1. / 1,906 "auc": 0. / 1,907 "labels/actual_label_mean": 1. / 1,908 "labels/prediction_mean": .731059, # softmax909 # eval loss is weighted loss divided by sum of weights.910 "loss": expected_total_loss,911 "precision/positive_threshold_0.500000_mean": 1. / 1,912 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,913 },914 model_fn_ops)915 def testBinaryClassificationWithCustomLoss(self):916 head = head_lib.multi_class_head(917 n_classes=2, weight_column_name="label_weight",918 loss_fn=_sigmoid_cross_entropy)919 with ops.Graph().as_default(), session.Session():920 weights = ((.2,), (0.,))921 model_fn_ops = head.create_model_fn_ops(922 features={"label_weight": weights},923 labels=self._labels,924 mode=model_fn.ModeKeys.TRAIN,925 train_op_fn=head_lib.no_op_train_fn,926 logits=self._logits)927 self._assert_output_alternatives(model_fn_ops)928 _assert_no_variables(self)929 _assert_summary_tags(self, ["loss"])930 # logloss: z:label, x:logit931 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))932 # expected_loss is (total_weighted_loss)/1 since there is 1 nonzero933 # weight.934 expected_loss = 0.062652342935 _assert_metrics(936 self,937 expected_loss,938 {939 "accuracy": 1. / 1,940 "accuracy/baseline_label_mean": 1. / 1,941 "accuracy/threshold_0.500000_mean": 1. / 1,942 "auc": 0. / 1,943 "labels/actual_label_mean": 1. / 1,944 "labels/prediction_mean": .731059, # softmax945 "loss": expected_loss,946 "precision/positive_threshold_0.500000_mean": 1. / 1,947 "recall/positive_threshold_0.500000_mean": 1. 
/ 1,948 },949 model_fn_ops)950 def testBinaryClassificationWithCenteredBias(self):951 head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)952 with ops.Graph().as_default(), session.Session():953 # logloss: z:label, x:logit954 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))955 model_fn_ops = head.create_model_fn_ops(956 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,957 logits=self._logits)958 self._assert_output_alternatives(model_fn_ops)959 _assert_variables(960 self,961 expected_global=(962 "binary_logistic_head/centered_bias_weight:0",963 ("binary_logistic_head/binary_logistic_head/centered_bias_weight/"964 "Adagrad:0"),),965 expected_trainable=("binary_logistic_head/centered_bias_weight:0",))966 variables.global_variables_initializer().run()967 _assert_summary_tags(self, [968 "loss",969 "binary_logistic_head/centered_bias/bias_0"970 ])971 expected_loss = .81326175972 _assert_metrics(self, expected_loss,973 self._expected_eval_metrics(expected_loss), model_fn_ops)974class MultiClassHeadTest(test.TestCase):975 def _assert_output_alternatives(self, model_fn_ops):976 self.assertEquals({977 None: constants.ProblemType.CLASSIFICATION978 }, {979 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)980 })981 def setUp(self):982 self._logits = ((1., 0., 0.),)983 self._labels = ((2,),)984 def _expected_eval_metrics(self, expected_loss):985 return {986 "accuracy": 0.,987 "loss": expected_loss,988 "labels/actual_label_mean/class0": 0. / 1,989 "labels/actual_label_mean/class1": 0. / 1,990 "labels/actual_label_mean/class2": 1. 
/ 1,991 "labels/logits_mean/class0": self._logits[0][0],992 "labels/logits_mean/class1": self._logits[0][1],993 "labels/logits_mean/class2": self._logits[0][2],994 "labels/prediction_mean/class0": self._logits[0][0],995 "labels/prediction_mean/class1": self._logits[0][1],996 "labels/prediction_mean/class2": self._logits[0][2],997 "labels/probability_mean/class0": 0.576117, # softmax998 "labels/probability_mean/class1": 0.211942, # softmax999 "labels/probability_mean/class2": 0.211942, # softmax1000 }1001 def testMultiClassWithLogits(self):1002 n_classes = 31003 head = head_lib.multi_class_head(1004 n_classes=n_classes, metric_class_ids=range(n_classes))1005 with ops.Graph().as_default(), session.Session():1006 # logloss: z:label, x:logit1007 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1008 model_fn_ops = head.create_model_fn_ops(1009 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1010 logits=self._logits)1011 self._assert_output_alternatives(model_fn_ops)1012 _assert_no_variables(self)1013 _assert_summary_tags(self, ["loss"])1014 expected_loss = 1.55144471015 _assert_metrics(self, expected_loss,1016 self._expected_eval_metrics(expected_loss), model_fn_ops)1017 def testMultiClassWithInvalidLogits(self):1018 head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)1019 with ops.Graph().as_default(), session.Session():1020 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1021 head.create_model_fn_ops(1022 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1023 logits=self._logits)1024 def testMultiClassWithNoneTrainOpFnInTrain(self):1025 head = head_lib.multi_class_head(n_classes=3)1026 with ops.Graph().as_default(), session.Session():1027 with self.assertRaisesRegexp(1028 ValueError, "train_op_fn can not be None in TRAIN mode"):1029 head.create_model_fn_ops(1030 {}, model_fn.ModeKeys.TRAIN, self._labels,1031 train_op_fn=None,1032 logits=self._logits)1033 def 
testMultiClassWithLogitsInput(self):1034 n_classes = 31035 head = head_lib.multi_class_head(1036 n_classes=n_classes, metric_class_ids=range(n_classes))1037 with ops.Graph().as_default(), session.Session():1038 # logloss: z:label, x:logit1039 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1040 model_fn_ops = head.create_model_fn_ops(1041 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1042 logits_input=((0., 0.),))1043 self._assert_output_alternatives(model_fn_ops)1044 w = ("multi_class_head/logits/weights:0",1045 "multi_class_head/logits/biases:0")1046 _assert_variables(1047 self, expected_global=w, expected_model=w, expected_trainable=w)1048 variables.global_variables_initializer().run()1049 _assert_summary_tags(self, ["loss"])1050 expected_loss = 1.09861231051 _assert_metrics(self, expected_loss, {1052 "accuracy": 0.,1053 "loss": expected_loss,1054 "labels/actual_label_mean/class0": 0. / 1,1055 "labels/actual_label_mean/class1": 0. / 1,1056 "labels/actual_label_mean/class2": 1. 
/ 1,1057 "labels/logits_mean/class0": 0.,1058 "labels/logits_mean/class1": 0.,1059 "labels/logits_mean/class2": 0.,1060 "labels/prediction_mean/class0": 1.,1061 "labels/prediction_mean/class1": 0.,1062 "labels/prediction_mean/class2": 0.,1063 "labels/probability_mean/class0": 0.333333, # softmax1064 "labels/probability_mean/class1": 0.333333, # softmax1065 "labels/probability_mean/class2": 0.333333, # softmax1066 }, model_fn_ops)1067 def testMultiClassWithLogitsAndLogitsInput(self):1068 n_classes = 31069 head = head_lib.multi_class_head(1070 n_classes=n_classes, metric_class_ids=range(n_classes))1071 with ops.Graph().as_default(), session.Session():1072 with self.assertRaisesRegexp(1073 ValueError, "Both logits and logits_input supplied"):1074 head.create_model_fn_ops(1075 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1076 logits_input=((0., 0.),), logits=self._logits)1077 def testMultiClassEnableCenteredBias(self):1078 n_classes = 31079 head = head_lib.multi_class_head(1080 n_classes=n_classes, enable_centered_bias=True)1081 with ops.Graph().as_default(), session.Session():1082 # logloss: z:label, x:logit1083 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1084 model_fn_ops = head.create_model_fn_ops(1085 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1086 logits=self._logits)1087 self._assert_output_alternatives(model_fn_ops)1088 _assert_variables(1089 self,1090 expected_global=(1091 "multi_class_head/centered_bias_weight:0",1092 ("multi_class_head/multi_class_head/centered_bias_weight/"1093 "Adagrad:0"),1094 ),1095 expected_trainable=("multi_class_head/centered_bias_weight:0",))1096 variables.global_variables_initializer().run()1097 _assert_summary_tags(self,1098 ["loss",1099 "multi_class_head/centered_bias/bias_0",1100 "multi_class_head/centered_bias/bias_1",1101 "multi_class_head/centered_bias/bias_2"])1102 def testMultiClassEval(self):1103 n_classes = 31104 head = head_lib.multi_class_head(1105 
n_classes=n_classes, metric_class_ids=range(n_classes))1106 with ops.Graph().as_default(), session.Session():1107 # logloss: z:label, x:logit1108 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1109 model_fn_ops = head.create_model_fn_ops(1110 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1111 logits=self._logits)1112 self._assert_output_alternatives(model_fn_ops)1113 self.assertIsNone(model_fn_ops.train_op)1114 _assert_no_variables(self)1115 _assert_summary_tags(self, ["loss"])1116 expected_loss = 1.55144471117 _assert_metrics(self, expected_loss,1118 self._expected_eval_metrics(expected_loss), model_fn_ops)1119 def testMultiClassEvalModeWithLargeLogits(self):1120 n_classes = 31121 head = head_lib.multi_class_head(1122 n_classes=n_classes, metric_class_ids=range(n_classes))1123 logits = ((2., 0., -1),)1124 with ops.Graph().as_default(), session.Session():1125 # logloss: z:label, x:logit1126 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1127 model_fn_ops = head.create_model_fn_ops(1128 {}, model_fn.ModeKeys.EVAL, self._labels, head_lib.no_op_train_fn,1129 logits=logits)1130 self._assert_output_alternatives(model_fn_ops)1131 self.assertIsNone(model_fn_ops.train_op)1132 _assert_no_variables(self)1133 _assert_summary_tags(self, ["loss"])1134 expected_loss = 3.16984611135 expected_eval_metrics = {1136 "accuracy": 0.,1137 "loss": expected_loss,1138 "labels/actual_label_mean/class0": 0. / 1,1139 "labels/actual_label_mean/class1": 0. / 1,1140 "labels/actual_label_mean/class2": 1. 
/ 1,1141 "labels/logits_mean/class0": logits[0][0],1142 "labels/logits_mean/class1": logits[0][1],1143 "labels/logits_mean/class2": logits[0][2],1144 "labels/prediction_mean/class0": 1,1145 "labels/prediction_mean/class1": 0,1146 "labels/prediction_mean/class2": 0,1147 "labels/probability_mean/class0": 0.843795, # softmax1148 "labels/probability_mean/class1": 0.114195, # softmax1149 "labels/probability_mean/class2": 0.0420101, # softmax1150 }1151 _assert_metrics(self, expected_loss,1152 expected_eval_metrics, model_fn_ops)1153 def testMultiClassWithScalarWeight(self):1154 n_classes = 31155 head = head_lib.multi_class_head(1156 n_classes=n_classes,1157 weight_column_name="label_weight",1158 metric_class_ids=range(n_classes))1159 with ops.Graph().as_default(), session.Session():1160 weight = .11161 # logloss: z:label, x:logit1162 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1163 model_fn_ops = head.create_model_fn_ops(1164 features={"label_weight": weight},1165 labels=self._labels,1166 mode=model_fn.ModeKeys.TRAIN,1167 train_op_fn=head_lib.no_op_train_fn,1168 logits=self._logits)1169 self._assert_output_alternatives(model_fn_ops)1170 _assert_no_variables(self)1171 _assert_summary_tags(self, ["loss"])1172 expected_loss = 1.55144471173 _assert_metrics(self, expected_loss * weight,1174 self._expected_eval_metrics(expected_loss), model_fn_ops)1175 def testMultiClassWith1DWeight(self):1176 n_classes = 31177 head = head_lib.multi_class_head(1178 n_classes=n_classes,1179 weight_column_name="label_weight",1180 metric_class_ids=range(n_classes))1181 with ops.Graph().as_default(), session.Session():1182 weight = .11183 weights = (weight,)1184 # logloss: z:label, x:logit1185 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1186 model_fn_ops = head.create_model_fn_ops(1187 features={"label_weight": weights},1188 labels=self._labels,1189 mode=model_fn.ModeKeys.TRAIN,1190 train_op_fn=head_lib.no_op_train_fn,1191 logits=self._logits)1192 
self._assert_output_alternatives(model_fn_ops)1193 _assert_no_variables(self)1194 _assert_summary_tags(self, ["loss"])1195 expected_loss = 1.55144471196 _assert_metrics(self, expected_loss * weight,1197 self._expected_eval_metrics(expected_loss), model_fn_ops)1198 def testMultiClassWith2DWeight(self):1199 n_classes = 31200 head = head_lib.multi_class_head(1201 n_classes=n_classes,1202 weight_column_name="label_weight",1203 metric_class_ids=range(n_classes))1204 with ops.Graph().as_default(), session.Session():1205 weight = .11206 weights = ((weight,),)1207 # logloss: z:label, x:logit1208 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1209 model_fn_ops = head.create_model_fn_ops(1210 features={"label_weight": weights},1211 labels=self._labels,1212 mode=model_fn.ModeKeys.TRAIN,1213 train_op_fn=head_lib.no_op_train_fn,1214 logits=self._logits)1215 self._assert_output_alternatives(model_fn_ops)1216 _assert_no_variables(self)1217 _assert_summary_tags(self, ["loss"])1218 expected_loss = 1.55144471219 _assert_metrics(self, expected_loss * weight,1220 self._expected_eval_metrics(expected_loss), model_fn_ops)1221 def testMultiClassWithCustomLoss(self):1222 n_classes = 31223 head = head_lib.multi_class_head(1224 n_classes=n_classes,1225 weight_column_name="label_weight",1226 metric_class_ids=range(n_classes),1227 loss_fn=losses_lib.sparse_softmax_cross_entropy)1228 with ops.Graph().as_default(), session.Session():1229 weight = .11230 # logloss: z:label, x:logit1231 # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))1232 model_fn_ops = head.create_model_fn_ops(1233 features={"label_weight": weight},1234 labels=self._labels,1235 mode=model_fn.ModeKeys.TRAIN,1236 train_op_fn=head_lib.no_op_train_fn,1237 logits=self._logits)1238 self._assert_output_alternatives(model_fn_ops)1239 _assert_no_variables(self)1240 _assert_summary_tags(self, ["loss"])1241 expected_loss = 1.5514447 * weight1242 _assert_metrics(self, expected_loss,1243 
self._expected_eval_metrics(expected_loss), model_fn_ops)1244 def testMultiClassInfer(self):1245 n_classes = 31246 head = head_lib._multi_class_head(1247 n_classes=n_classes,1248 head_name="head_name")1249 with ops.Graph().as_default():1250 model_fn_ops = head.create_model_fn_ops(1251 features={},1252 mode=model_fn.ModeKeys.INFER,1253 train_op_fn=head_lib.no_op_train_fn,1254 logits=((1., 0., 0.), (0., 0., 1.),))1255 with session.Session():1256 lookup_ops.tables_initializer().run()1257 self.assertAllEqual(1258 [0, 2],1259 model_fn_ops.predictions["classes"].eval())1260 self.assertItemsEqual(1261 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))1262 self.assertEqual(1263 constants.ProblemType.CLASSIFICATION,1264 model_fn_ops.output_alternatives["head_name"][0])1265 predictions_for_serving = (1266 model_fn_ops.output_alternatives["head_name"][1])1267 self.assertIn("classes", six.iterkeys(predictions_for_serving))1268 self.assertAllEqual(1269 [[b"0", b"1", b"2"], [b"0", b"1", b"2"]],1270 predictions_for_serving["classes"].eval())1271 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1272 self.assertAllClose(1273 [[0.576117, 0.2119416, 0.2119416],1274 [0.2119416, 0.2119416, 0.576117]],1275 predictions_for_serving["probabilities"].eval())1276 def testInvalidNClasses(self):1277 for n_classes in (None, -1, 0, 1):1278 with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):1279 head_lib.multi_class_head(n_classes=n_classes)1280 def testMultiClassWithLabelKeysInvalidShape(self):1281 with self.assertRaisesRegexp(1282 ValueError, "Length of label_keys must equal n_classes"):1283 head_lib._multi_class_head(1284 n_classes=3, label_keys=("key0", "key1"))1285 def testMultiClassWithLabelKeysTwoClasses(self):1286 with self.assertRaisesRegexp(1287 ValueError, "label_keys is not supported for n_classes=2"):1288 head_lib._multi_class_head(1289 n_classes=2, label_keys=("key0", "key1"))1290 def testMultiClassWithLabelKeysInfer(self):1291 
n_classes = 31292 label_keys = ("key0", "key1", "key2")1293 head = head_lib._multi_class_head(1294 n_classes=n_classes, label_keys=label_keys,1295 metric_class_ids=range(n_classes),1296 head_name="head_name")1297 with ops.Graph().as_default():1298 model_fn_ops = head.create_model_fn_ops(1299 features={},1300 mode=model_fn.ModeKeys.INFER,1301 train_op_fn=head_lib.no_op_train_fn,1302 logits=((1., 0., 0.), (0., 0., 1.),))1303 with session.Session():1304 lookup_ops.tables_initializer().run()1305 self.assertAllEqual(1306 [b"key0", b"key2"],1307 model_fn_ops.predictions["classes"].eval())1308 self.assertItemsEqual(1309 ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))1310 self.assertEqual(1311 constants.ProblemType.CLASSIFICATION,1312 model_fn_ops.output_alternatives["head_name"][0])1313 predictions_for_serving = (1314 model_fn_ops.output_alternatives["head_name"][1])1315 self.assertIn("classes", six.iterkeys(predictions_for_serving))1316 self.assertAllEqual(1317 [[b"key0", b"key1", b"key2"], [b"key0", b"key1", b"key2"]],1318 predictions_for_serving["classes"].eval())1319 self.assertIn("probabilities", six.iterkeys(predictions_for_serving))1320 self.assertAllClose(1321 [[0.576117, 0.2119416, 0.2119416],1322 [0.2119416, 0.2119416, 0.576117]],1323 predictions_for_serving["probabilities"].eval())1324 def testMultiClassWithLabelKeysEvalAccuracy0(self):1325 n_classes = 31326 label_keys = ("key0", "key1", "key2")1327 head = head_lib._multi_class_head(1328 n_classes=n_classes,1329 label_keys=label_keys)1330 with ops.Graph().as_default():1331 model_fn_ops = head.create_model_fn_ops(1332 features={},1333 mode=model_fn.ModeKeys.EVAL,1334 labels=("key2",),1335 train_op_fn=head_lib.no_op_train_fn,1336 logits=((1., 0., 0.),))1337 with session.Session():1338 lookup_ops.tables_initializer().run()1339 self.assertIsNone(model_fn_ops.train_op)1340 _assert_no_variables(self)1341 _assert_summary_tags(self, ["loss"])1342 expected_loss = 1.55144471343 expected_eval_metrics = 
{1344 "accuracy": 0.,1345 "loss": expected_loss,1346 }1347 _assert_metrics(self, expected_loss,1348 expected_eval_metrics, model_fn_ops)1349 def testMultiClassWithLabelKeysEvalAccuracy1(self):1350 n_classes = 31351 label_keys = ("key0", "key1", "key2")1352 head = head_lib._multi_class_head(1353 n_classes=n_classes,1354 label_keys=label_keys)1355 with ops.Graph().as_default():1356 model_fn_ops = head.create_model_fn_ops(1357 features={},1358 mode=model_fn.ModeKeys.EVAL,1359 labels=("key2",),1360 train_op_fn=head_lib.no_op_train_fn,1361 logits=((0., 0., 1.),))1362 with session.Session():1363 lookup_ops.tables_initializer().run()1364 self.assertIsNone(model_fn_ops.train_op)1365 _assert_no_variables(self)1366 _assert_summary_tags(self, ["loss"])1367 expected_loss = 0.55144471368 expected_eval_metrics = {1369 "accuracy": 1.,1370 "loss": expected_loss,1371 }1372 _assert_metrics(self, expected_loss,1373 expected_eval_metrics, model_fn_ops)1374class BinarySvmHeadTest(test.TestCase):1375 def _assert_output_alternatives(self, model_fn_ops):1376 self.assertEquals({1377 None: constants.ProblemType.LOGISTIC_REGRESSION1378 }, {1379 k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1380 })1381 def setUp(self):1382 # Prediction for first example is in the right side of the hyperplane1383 # (i.e., < 0) but it is within the [-1,1] margin. There is a 0.5 loss1384 # incurred by this example. 
The 2nd prediction is outside the margin so it1385 # incurs no loss at all.1386 self._predictions = ((-.5,), (1.2,))1387 self._labels = (0, 1)1388 self._expected_losses = (.5, 0.)1389 def testBinarySVMWithLogits(self):1390 head = head_lib.binary_svm_head()1391 with ops.Graph().as_default(), session.Session():1392 model_fn_ops = head.create_model_fn_ops(1393 {},1394 model_fn.ModeKeys.TRAIN,1395 self._labels,1396 head_lib.no_op_train_fn,1397 logits=self._predictions)1398 self._assert_output_alternatives(model_fn_ops)1399 _assert_no_variables(self)1400 _assert_summary_tags(self, ["loss"])1401 expected_loss = np.average(self._expected_losses)1402 _assert_metrics(self, expected_loss, {1403 "accuracy": 1.,1404 "loss": expected_loss,1405 }, model_fn_ops)1406 def testBinarySVMWithInvalidLogits(self):1407 head = head_lib.binary_svm_head()1408 with ops.Graph().as_default(), session.Session():1409 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):1410 head.create_model_fn_ops(1411 {}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,1412 logits=np.ones((2, 2)))1413 def testBinarySVMWithLogitsInput(self):1414 head = head_lib.binary_svm_head()1415 with ops.Graph().as_default(), session.Session():1416 model_fn_ops = head.create_model_fn_ops(1417 {},1418 model_fn.ModeKeys.TRAIN,1419 self._labels,1420 head_lib.no_op_train_fn,1421 logits_input=((0., 0.), (0., 0.)))1422 self._assert_output_alternatives(model_fn_ops)1423 w = ("binary_svm_head/logits/weights:0",1424 "binary_svm_head/logits/biases:0")1425 _assert_variables(1426 self, expected_global=w, expected_model=w, expected_trainable=w)1427 variables.global_variables_initializer().run()1428 _assert_summary_tags(self, ["loss"])1429 expected_loss = 1.1430 _assert_metrics(self, expected_loss, {1431 "accuracy": .5,1432 "loss": expected_loss,1433 }, model_fn_ops)1434 def testBinarySVMWithLogitsAndLogitsInput(self):1435 head = head_lib.binary_svm_head()1436 with ops.Graph().as_default(), 
session.Session():1437 with self.assertRaisesRegexp(1438 ValueError, "Both logits and logits_input supplied"):1439 head.create_model_fn_ops(1440 {},1441 model_fn.ModeKeys.TRAIN,1442 self._labels,1443 head_lib.no_op_train_fn,1444 logits_input=((0., 0.), (0., 0.)),1445 logits=self._predictions)1446 def testBinarySVMEvalMode(self):1447 head = head_lib.binary_svm_head()1448 with ops.Graph().as_default(), session.Session():1449 model_fn_ops = head.create_model_fn_ops(1450 {},1451 model_fn.ModeKeys.EVAL,1452 self._labels,1453 head_lib.no_op_train_fn,1454 logits=self._predictions)1455 self._assert_output_alternatives(model_fn_ops)1456 self.assertIsNone(model_fn_ops.train_op)1457 _assert_no_variables(self)1458 _assert_summary_tags(self, ["loss"])1459 expected_loss = np.average(self._expected_losses)1460 _assert_metrics(self, expected_loss, {1461 "accuracy": 1.,1462 "loss": expected_loss,1463 }, model_fn_ops)1464 def testBinarySVMWithLabelName(self):1465 label_name = "my_label"1466 head = head_lib.binary_svm_head(label_name=label_name)1467 with ops.Graph().as_default(), session.Session():1468 model_fn_ops = head.create_model_fn_ops(1469 {},1470 model_fn.ModeKeys.TRAIN,1471 {label_name: self._labels},1472 head_lib.no_op_train_fn,1473 logits=self._predictions)1474 self._assert_output_alternatives(model_fn_ops)1475 _assert_no_variables(self)1476 _assert_summary_tags(self, ["loss"])1477 expected_loss = np.average(self._expected_losses)1478 _assert_metrics(self, expected_loss, {1479 "accuracy": 1.,1480 "loss": expected_loss,1481 }, model_fn_ops)1482 def testBinarySVMWith1DWeights(self):1483 head = head_lib.binary_svm_head(weight_column_name="weights")1484 with ops.Graph().as_default(), session.Session():1485 weights = (7., 11.)1486 model_fn_ops = head.create_model_fn_ops(1487 # We have to add an extra dim here for weights broadcasting to work.1488 features={"weights": weights},1489 mode=model_fn.ModeKeys.TRAIN,1490 labels=self._labels,1491 
train_op_fn=head_lib.no_op_train_fn,1492 logits=self._predictions)1493 self._assert_output_alternatives(model_fn_ops)1494 _assert_no_variables(self)1495 _assert_summary_tags(self, ["loss"])1496 expected_weighted_losses = np.multiply(weights, self._expected_losses)1497 _assert_metrics(self, np.mean(expected_weighted_losses), {1498 "accuracy": 1.,1499 "loss": np.sum(expected_weighted_losses) / np.sum(weights),1500 }, model_fn_ops)1501 def testBinarySVMWith2DWeights(self):1502 head = head_lib.binary_svm_head(weight_column_name="weights")1503 with ops.Graph().as_default(), session.Session():1504 weights = (7., 11.)1505 model_fn_ops = head.create_model_fn_ops(1506 # We have to add an extra dim here for weights broadcasting to work.1507 features={"weights": tuple([(w,) for w in weights])},1508 mode=model_fn.ModeKeys.TRAIN,1509 labels=self._labels,1510 train_op_fn=head_lib.no_op_train_fn,1511 logits=self._predictions)1512 self._assert_output_alternatives(model_fn_ops)1513 _assert_no_variables(self)1514 _assert_summary_tags(self, ["loss"])1515 expected_weighted_losses = np.multiply(weights, self._expected_losses)1516 _assert_metrics(self, np.mean(expected_weighted_losses), {1517 "accuracy": 1.,1518 "loss": np.sum(expected_weighted_losses) / np.sum(weights),1519 }, model_fn_ops)1520 def testBinarySVMWithCenteredBias(self):1521 head = head_lib.binary_svm_head(enable_centered_bias=True)1522 with ops.Graph().as_default(), session.Session():1523 model_fn_ops = head.create_model_fn_ops(1524 {},1525 model_fn.ModeKeys.TRAIN,1526 self._labels,1527 head_lib.no_op_train_fn,1528 logits=self._predictions)1529 self._assert_output_alternatives(model_fn_ops)1530 _assert_variables(1531 self,1532 expected_global=(1533 "binary_svm_head/centered_bias_weight:0",1534 ("binary_svm_head/binary_svm_head/centered_bias_weight/"1535 "Adagrad:0"),1536 ),1537 expected_trainable=("binary_svm_head/centered_bias_weight:0",))1538 variables.global_variables_initializer().run()1539 _assert_summary_tags(self, 
[1540 "loss",1541 "binary_svm_head/centered_bias/bias_0"1542 ])1543 expected_loss = np.average(self._expected_losses)1544 _assert_metrics(self, expected_loss, {1545 "accuracy": 1.,1546 "loss": expected_loss,1547 }, model_fn_ops)1548class LossOnlyHead(test.TestCase):1549 def testNoPredictionsAndNoMetrics(self):1550 head = head_lib.loss_only_head(lambda: 1, head_name="const")1551 model_fn_ops = head.create_model_fn_ops(1552 features={},1553 mode=model_fn.ModeKeys.TRAIN,1554 train_op_fn=head_lib.no_op_train_fn)1555 self.assertDictEqual(model_fn_ops.predictions, {})1556 self.assertDictEqual(model_fn_ops.eval_metric_ops, {})1557 self.assertIsNotNone(model_fn_ops.loss)1558 with session.Session() as sess:1559 self.assertEqual(1, sess.run(model_fn_ops.loss))1560class MultiHeadTest(test.TestCase):1561 def testInvalidHeads(self):1562 named_head = head_lib.multi_class_head(1563 n_classes=3, label_name="label", head_name="head1")1564 unnamed_head = head_lib.multi_class_head(1565 n_classes=4, label_name="label")1566 with self.assertRaisesRegexp(ValueError, "must have names"):1567 head_lib.multi_head((named_head, unnamed_head))1568 def testTrainWithNoneTrainOpFn(self):1569 head1 = head_lib.multi_class_head(1570 n_classes=3, label_name="label1", head_name="head1")1571 head2 = head_lib.multi_class_head(1572 n_classes=4, label_name="label2", head_name="head2")1573 head = head_lib.multi_head((head1, head2))1574 labels = {1575 "label1": (1,),1576 "label2": (1,)1577 }1578 with self.assertRaisesRegexp(1579 ValueError, "train_op_fn can not be None in TRAIN mode"):1580 head.create_model_fn_ops(1581 features={"weights": (2.0, 10.0)},1582 labels=labels,1583 mode=model_fn.ModeKeys.TRAIN,1584 train_op_fn=None,1585 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1586 def testTrain_withNoHeadWeights(self):1587 head1 = head_lib.multi_class_head(1588 n_classes=3, label_name="label1", head_name="head1")1589 head2 = head_lib.multi_class_head(1590 n_classes=4, label_name="label2", head_name="head2")1591 
head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")1592 head = head_lib.multi_head((head1, head2, head3))1593 labels = {1594 "label1": (1,),1595 "label2": (1,)1596 }1597 model_fn_ops = head.create_model_fn_ops(1598 features={"weights": (2.0, 10.0)},1599 labels=labels,1600 mode=model_fn.ModeKeys.TRAIN,1601 train_op_fn=head_lib.no_op_train_fn,1602 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1603 self.assertIsNone(model_fn_ops.predictions)1604 self.assertIsNotNone(model_fn_ops.loss)1605 self.assertIsNotNone(model_fn_ops.train_op)1606 self.assertTrue(model_fn_ops.eval_metric_ops)1607 self.assertIsNone(model_fn_ops.output_alternatives)1608 with session.Session() as sess:1609 self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)1610 def testTrain_withHeadWeights(self):1611 head1 = head_lib.multi_class_head(1612 n_classes=3, label_name="label1", head_name="head1")1613 head2 = head_lib.multi_class_head(1614 n_classes=4, label_name="label2", head_name="head2")1615 head = head_lib.multi_head((head1, head2), (1, .5))1616 labels = {1617 "label1": (1,),1618 "label2": (1,)1619 }1620 model_fn_ops = head.create_model_fn_ops(1621 features={"weights": (2.0, 10.0)},1622 labels=labels,1623 mode=model_fn.ModeKeys.TRAIN,1624 train_op_fn=head_lib.no_op_train_fn,1625 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1626 self.assertIsNone(model_fn_ops.predictions)1627 self.assertIsNotNone(model_fn_ops.loss)1628 self.assertIsNotNone(model_fn_ops.train_op)1629 self.assertTrue(model_fn_ops.eval_metric_ops)1630 self.assertIsNone(model_fn_ops.output_alternatives)1631 with session.Session() as sess:1632 self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)1633 def testTrain_withDictLogits(self):1634 head1 = head_lib.multi_class_head(1635 n_classes=3, label_name="label1", head_name="head1")1636 head2 = head_lib.multi_class_head(1637 n_classes=4, label_name="label2", head_name="head2")1638 head = head_lib.multi_head((head1, head2))1639 labels = {1640 "label1": 
(1,),1641 "label2": (1,)1642 }1643 model_fn_ops = head.create_model_fn_ops(1644 features={"weights": (2.0, 10.0)},1645 labels=labels,1646 mode=model_fn.ModeKeys.TRAIN,1647 train_op_fn=head_lib.no_op_train_fn,1648 logits={head1.head_name: ((-0.7, 0.2, .1),),1649 head2.head_name: ((.1, .1, .1, .1),)})1650 self.assertIsNone(model_fn_ops.predictions)1651 self.assertIsNotNone(model_fn_ops.loss)1652 self.assertIsNotNone(model_fn_ops.train_op)1653 self.assertTrue(model_fn_ops.eval_metric_ops)1654 self.assertIsNone(model_fn_ops.output_alternatives)1655 with session.Session() as sess:1656 self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)1657 def testInfer(self):1658 head1 = head_lib.multi_class_head(1659 n_classes=3, label_name="label1", head_name="head1")1660 head2 = head_lib.multi_class_head(1661 n_classes=4, label_name="label2", head_name="head2")1662 head = head_lib.multi_head((head1, head2), (1, .5))1663 labels = {1664 "label1": (1,),1665 "label2": (1,)1666 }1667 model_fn_ops = head.create_model_fn_ops(1668 features={"weights": (2.0, 10.0)},1669 labels=labels,1670 mode=model_fn.ModeKeys.INFER,1671 train_op_fn=head_lib.no_op_train_fn,1672 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1673 self.assertIsNotNone(model_fn_ops.predictions)1674 self.assertIsNone(model_fn_ops.loss)1675 self.assertIsNone(model_fn_ops.train_op)1676 self.assertFalse(model_fn_ops.eval_metric_ops)1677 # Tests predictions keys.1678 self.assertItemsEqual((1679 ("head1", prediction_key.PredictionKey.LOGITS),1680 ("head1", prediction_key.PredictionKey.PROBABILITIES),1681 ("head1", prediction_key.PredictionKey.CLASSES),1682 ("head2", prediction_key.PredictionKey.LOGITS),1683 ("head2", prediction_key.PredictionKey.PROBABILITIES),1684 ("head2", prediction_key.PredictionKey.CLASSES),1685 ), model_fn_ops.predictions.keys())1686 # Tests output alternative.1687 self.assertEquals({1688 "head1": constants.ProblemType.CLASSIFICATION,1689 "head2": constants.ProblemType.CLASSIFICATION,1690 }, {1691 
k: v[0] for k, v in six.iteritems(model_fn_ops.output_alternatives)1692 })1693 self.assertItemsEqual((1694 prediction_key.PredictionKey.PROBABILITIES,1695 prediction_key.PredictionKey.CLASSES,1696 ), model_fn_ops.output_alternatives["head1"][1].keys())1697 self.assertItemsEqual((1698 prediction_key.PredictionKey.PROBABILITIES,1699 prediction_key.PredictionKey.CLASSES,1700 ), model_fn_ops.output_alternatives["head2"][1].keys())1701 def testEval(self):1702 head1 = head_lib.multi_class_head(1703 n_classes=3, label_name="label1", head_name="head1")1704 head2 = head_lib.multi_class_head(1705 n_classes=4, label_name="label2", head_name="head2")1706 head = head_lib.multi_head((head1, head2), (1, .5))1707 labels = {1708 "label1": (1,),1709 "label2": (1,)1710 }1711 model_fn_ops = head.create_model_fn_ops(1712 features={"weights": (2.0, 10.0)},1713 labels=labels,1714 mode=model_fn.ModeKeys.EVAL,1715 train_op_fn=head_lib.no_op_train_fn,1716 logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))1717 self.assertIsNotNone(model_fn_ops.predictions)1718 self.assertIsNotNone(model_fn_ops.loss)1719 self.assertIsNone(model_fn_ops.train_op)1720 self.assertIsNotNone(model_fn_ops.eval_metric_ops)...
multi_head.py
Source:multi_head.py
...11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12# See the License for the specific language governing permissions and13# limitations under the License.14# ==============================================================================15"""Abstractions for the head(s) of a model.16"""17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20import six21from tensorflow.python.estimator import model_fn22from tensorflow.python.estimator.canned import head as head_lib23from tensorflow.python.estimator.canned import metric_keys24from tensorflow.python.estimator.export import export_output as export_output_lib25from tensorflow.python.framework import ops26from tensorflow.python.ops import array_ops27from tensorflow.python.ops import control_flow_ops28from tensorflow.python.ops import math_ops29from tensorflow.python.ops import metrics as metrics_lib30from tensorflow.python.saved_model import signature_constants31from tensorflow.python.summary import summary32from tensorflow.python.training import training_util33_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY34def multi_head(heads, head_weights=None):35 """Creates a `_Head` for multi-objective learning.36 This class merges the output of multiple `_Head` objects.37 Specifically:38 * For training, sums losses of each head, calls `train_op_fn` with this39 final loss.40 * For eval, merges metrics by adding `head.name` suffix to the keys in eval41 metrics, such as `precision/head1`, `precision/head2`.42 * For prediction, merges predictions and updates keys in prediction dict to a43 2-tuple, `(head.name, prediction_key)`. 
Merges `export_outputs` such that44 by default the first head is served.45 Usage:46 ```python47 # In `input_fn` specify labels as a dict keyed by head name:48 def input_fn():49 features = ...50 labels1 = ...51 labels2 = ...52 return features, {'head1': labels1, 'head2': labels2}53 # In `model_fn`, specify logits as a dict keyed by head name:54 def model_fn(features, labels, mode):55 # Create simple heads and specify head name.56 head1 = multi_class_head(n_classes=3, name='head1')57 head2 = binary_classification_head(name='head2')58 # Create multi-head from two simple heads.59 head = multi_head([head1, head2])60 # Create logits for each head, and combine them into a dict.61 logits1, logits2 = logit_fn()62 logits = {'head1': logits1, 'head2': logits2}63 # Return the merged EstimatorSpec64 return head.create_estimator_spec(..., logits=logits, ...)65 # Create an estimator with this model_fn.66 estimator = tf.estimator.Estimator(model_fn=model_fn)67 estimator.train(input_fn=input_fn, steps=100)68 ```69 Also supports `logits` as a `Tensor` of shape70 `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the71 last dimension and distribute it appropriately among the heads. E.g.:72 ```python73 def model_fn(features, labels, mode):74 # Create simple heads and specify head name.75 head1 = multi_class_head(n_classes=3, name='head1')76 head2 = binary_classification_head(name='head2')77 # Create multi-head from two simple heads.78 head = multi_head([head1, head2])79 # Create logits for the multihead.80 logits = logit_fn(logits_dimension=head.logits_dimension)81 # Return the merged EstimatorSpec82 return head.create_estimator_spec(..., logits=logits, ...)83 ```84 Args:85 heads: List or tuple of `_Head` instances. All heads must have `name`86 specified. The first head in the list is the default used at serving time.87 head_weights: Optional list of weights, same length as `heads`. Used when88 merging losses to calculate the weighted sum of losses from each head. 
If89 `None`, all losses are weighted equally.90 Returns:91 A instance of `_Head` that merges multiple heads.92 Raises:...
net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py
Source:net_HBS_solo_resnet18_score_fuse_FFM4_c2_nonlocal_ly24_mout_rnn_ly1_h1_ly4321.py
1import torch2import torch.nn as nn3import torch.nn.functional as F4from net.resnet import resnet34, resnet185import copy6from torch.nn import init7from IPython import embed8from auxi.module import FFM_v4, RNNModule9def weights_init_kaiming(m):10 classname = m.__class__.__name__11 if classname.find('Conv') != -1:12 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')13 elif classname.find('Linear') != -1:14 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')15 nn.init.constant_(m.bias.data, 0.0)16 elif classname.find('BatchNorm1d') != -1:17 nn.init.normal_(m.weight.data, 1.0, 0.02)18 nn.init.constant_(m.bias.data, 0.0)19class net(nn.Module):20 def __init__(self, config):21 super(net, self).__init__()22 self.net_head = resnet18(pretrained=True, num_classes=26)23 self.head_layer1 = nn.Sequential(*list(self.net_head.children())[:5])24 self.head_layer2 = list(self.net_head.children())[5]25 self.head_layer3 = list(self.net_head.children())[6]26 self.head_layer4 = list(self.net_head.children())[7]27 self.head_cls = list(self.net_head.children())[-1]28 self.net_body = resnet18(pretrained=True, num_classes=26)29 self.body_layer1 = nn.Sequential(*list(self.net_body.children())[:5])30 self.body_layer2 = list(self.net_body.children())[5]31 self.body_layer3 = list(self.net_body.children())[6]32 self.body_layer4 = list(self.net_body.children())[7]33 self.body_cls = list(self.net_body.children())[-1]34 self.net_scene = resnet18(pretrained=True, num_classes=26)35 self.scene_layer1 = nn.Sequential(*list(self.net_scene.children())[:5])36 self.scene_layer2 = list(self.net_scene.children())[5]37 self.scene_layer3 = list(self.net_scene.children())[6]38 self.scene_layer4 = list(self.net_scene.children())[7]39 self.scene_cls = list(self.net_scene.children())[-1]40 self.fc3 = nn.Sequential(nn.Linear(512*3 , 128), nn.Linear(128, 26))41 self.csr_ly1 = RNNModule(input_size=64, hidden_size=64, num_layers=1)42 self.csr_ly2 = RNNModule(input_size=128, hidden_size=128, 
num_layers=1)43 self.csr_ly3 = RNNModule(input_size=256, hidden_size=256, num_layers=1)44 self.csr_ly4 = RNNModule(input_size=512, hidden_size=512, num_layers=1)45 self.ffm_head_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)46 self.ffm_body_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)47 self.ffm_scene_ly2 = FFM_v4(dimension=2, in_channel=128, inter_channel=128*2)48 self.ffm_head_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)49 self.ffm_body_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)50 self.ffm_scene_ly4 = FFM_v4(dimension=2, in_channel=512, inter_channel=512*2)51 def forward(self, data, mode):52 head_ly1 = self.head_layer1(data['image_head'])53 body_ly1 = self.body_layer1(data['image_body'])54 scene_ly1 = self.scene_layer1(data['image_scene'])55 head_ly1_ = F.adaptive_avg_pool2d(head_ly1, (1,1)).view(head_ly1.size(0), -1)56 body_ly1_ = F.adaptive_avg_pool2d(body_ly1, (1,1)).view(body_ly1.size(0), -1)57 scene_ly1_ = F.adaptive_avg_pool2d(scene_ly1, (1,1)).view(scene_ly1.size(0), -1)58 feature = self.csr_ly1(torch.stack((head_ly1_, body_ly1_, scene_ly1_), dim=1))59 head_ly1 = head_ly1+feature[:,0,:].view(head_ly1.size(0),head_ly1.size(1),1,1).expand_as(head_ly1)60 body_ly1 = body_ly1+feature[:,1,:].view(body_ly1.size(0),body_ly1.size(1),1,1).expand_as(body_ly1)61 scene_ly1 = scene_ly1+feature[:,2,:].view(scene_ly1.size(0),scene_ly1.size(1),1,1).expand_as(scene_ly1)62 head_ly2 = self.head_layer2(head_ly1)63 body_ly2 = self.body_layer2(body_ly1)64 scene_ly2 = self.scene_layer2(scene_ly1)65 head_ly2_ = F.adaptive_avg_pool2d(head_ly2, (1,1)).view(head_ly2.size(0), -1)66 body_ly2_ = F.adaptive_avg_pool2d(body_ly2, (1,1)).view(body_ly2.size(0), -1)67 scene_ly2_ = F.adaptive_avg_pool2d(scene_ly2, (1,1)).view(scene_ly2.size(0), -1)68 feature = self.csr_ly2(torch.stack((head_ly2_, body_ly2_, scene_ly2_), dim=1))69 head_ly2 = 
head_ly2+feature[:,0,:].view(head_ly2.size(0),head_ly2.size(1),1,1).expand_as(head_ly2)70 body_ly2 = body_ly2+feature[:,1,:].view(body_ly2.size(0),body_ly2.size(1),1,1).expand_as(body_ly2)71 scene_ly2 = scene_ly2+feature[:,2,:].view(scene_ly2.size(0),scene_ly2.size(1),1,1).expand_as(scene_ly2)72 head_ly2 = self.ffm_head_ly2(head_ly2, body_ly2, scene_ly2) + head_ly273 body_ly2 = self.ffm_body_ly2(body_ly2, head_ly2, scene_ly2) + body_ly274 scene_ly2 = self.ffm_scene_ly2(scene_ly2, head_ly2, body_ly2) + scene_ly275 head_ly3 = self.head_layer3(head_ly2)76 body_ly3 = self.body_layer3(body_ly2)77 scene_ly3 = self.scene_layer3(scene_ly2)78 head_ly3_ = F.adaptive_avg_pool2d(head_ly3, (1,1)).view(head_ly3.size(0), -1)79 body_ly3_ = F.adaptive_avg_pool2d(body_ly3, (1,1)).view(body_ly3.size(0), -1)80 scene_ly3_ = F.adaptive_avg_pool2d(scene_ly3, (1,1)).view(scene_ly3.size(0), -1)81 feature = self.csr_ly3(torch.stack((head_ly3_, body_ly3_, scene_ly3_), dim=1))82 head_ly3 = head_ly3+feature[:,0,:].view(head_ly3.size(0),head_ly3.size(1),1,1).expand_as(head_ly3)83 body_ly3 = body_ly3+feature[:,1,:].view(body_ly3.size(0),body_ly3.size(1),1,1).expand_as(body_ly3)84 scene_ly3 = scene_ly3+feature[:,2,:].view(scene_ly3.size(0),scene_ly3.size(1),1,1).expand_as(scene_ly3)85 head_ly4 = self.head_layer4(head_ly3)86 body_ly4 = self.body_layer4(body_ly3)87 scene_ly4 = self.scene_layer4(scene_ly3)88 head_ly4_ = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)89 body_ly4_ = F.adaptive_avg_pool2d(body_ly4, (1,1)).view(body_ly4.size(0), -1)90 scene_ly4_ = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)91 feature = self.csr_ly4(torch.stack((head_ly4_, body_ly4_, scene_ly4_), dim=1))92 head_ly4 = head_ly4+feature[:,0,:].view(head_ly4.size(0),head_ly4.size(1),1,1).expand_as(head_ly4)93 body_ly4 = body_ly4+feature[:,1,:].view(body_ly4.size(0),body_ly4.size(1),1,1).expand_as(body_ly4)94 scene_ly4 = 
scene_ly4+feature[:,2,:].view(scene_ly4.size(0),scene_ly4.size(1),1,1).expand_as(scene_ly4)95 head_ly4 = self.ffm_head_ly4(head_ly4, body_ly4, scene_ly4) + head_ly496 body_ly4 = self.ffm_body_ly4(body_ly4, head_ly4, scene_ly4) + body_ly497 scene_ly4 = self.ffm_scene_ly4(scene_ly4, head_ly4, body_ly4) + scene_ly498 head_ly4 = F.adaptive_avg_pool2d(head_ly4, (1,1)).view(head_ly4.size(0), -1)99 body_ly4 = F.adaptive_avg_pool2d(body_ly4, (1,1)).view(body_ly4.size(0), -1)100 scene_ly4 = F.adaptive_avg_pool2d(scene_ly4, (1,1)).view(scene_ly4.size(0), -1)101 out_head = self.head_cls(head_ly4)102 out_body = self.body_cls(body_ly4)103 out_scene = self.scene_cls(scene_ly4)104 out = self.fc3(torch.cat((head_ly4, body_ly4, scene_ly4), dim=1))...
head_tail.py
Source:head_tail.py
1"""Utilities to manage head and tail of elements2The scope is to avoid loosing part of the original text in the final tree.3"""4from .tree import Item5class TokenValue:6 def __init__(self, value):7 self.value = value8 self.pos = None9 self.size = None10 self.head = ""11 self.tail = ""12 def __repr__(self):13 return "TokenValue(%s)" % self.value14 def __str__(self):15 return self.value16class HeadTailLexer:17 """Utility to handle head and tail at lexer time.18 """19 LEXER_ATTR = "_luqum_headtail"20 @classmethod21 def handle(cls, token, orig_value):22 """Handling a token.23 .. note::24 PLY does not gives acces to previous tokens,25 although it does not provide any infrastructure for handling specific state.26 So we use the strategy27 of puting a :py:cls:`HeadTailLexer`instance as an attribute of the lexer28 each time we start a new tokenization.29 """30 # get instance31 if token.lexpos == 0:32 # first token make instance33 instance = cls()34 setattr(token.lexer, cls.LEXER_ATTR, instance)35 else:36 instance = getattr(token.lexer, cls.LEXER_ATTR)37 # handle38 instance.handle_token(token, orig_value)39 def __init__(self):40 self.head = None41 """This will track the head of next element, useful only for first element42 """43 self.last_elt = None44 """This will track the last token, so we can use it to add the tail to it.45 """46 def handle_token(self, token, orig_value):47 """Handle head and tail for tokens48 The scope is to avoid loosing part of the original text and keep it in elements.49 """50 # handle headtail51 if token.type == "SEPARATOR":52 if token.lexpos == 0:53 # spaces at expression start, head for next token54 self.head = token.value55 else:56 # tail of last processed token57 if self.last_elt is not None:58 self.last_elt.value.tail += token.value59 else:60 # if there is a head, apply61 head = self.head62 if head is not None:63 token.value.head = head64 self.head = None65 # keep tracks of token, to apply tail later66 self.last_elt = token67 # also set pos and 
size68 if isinstance(token.value, (Item, TokenValue)):69 token.value.pos = token.lexpos70 token.value.size = len(orig_value)71token_headtail = HeadTailLexer.handle72class HeadTailManager:73 """Utility to hande head and tail at expression parse time74 """75 def pos(self, p, head_transfer=False, tail_transfer=False):76 """Compute pos and size of element 0 based on it's parts (p[1:])77 :param list p: the parser expression as in PLY78 :param bool head_transfer: True if head of first child will be transfered to p[0]79 :param bool tail_transfer: True if tail of last child wiil be transfered to p[0]80 """81 # pos82 if p[1].pos is not None:83 p[0].pos = p[1].pos84 if not head_transfer:85 # head is'nt transfered, so we are before it86 p[0].pos -= len(p[1].head)87 # size88 p[0].size = sum(89 (elt.size or 0) + len(elt.head or "") + len(elt.tail or "") for elt in p[1:])90 if head_transfer and p[1].head:91 # we account head in size, remove it92 p[0].size -= len(p[1].head)93 last_p = p[len(p) - 1] # negative indexing not supported by PLY94 if tail_transfer and last_p.tail:95 # we account head in size, remove it96 p[0].size -= len(last_p.tail)97 def binary_operation(self, p, op_tail):98 self.pos(p, head_transfer=False, tail_transfer=False)99 # correct size100 p[0].size -= len(op_tail)101 def simple_term(self, p):102 self.pos(p, head_transfer=True, tail_transfer=True)103 p[0].head = p[1].head104 p[0].tail = p[1].tail105 def unary(self, p):106 """OP expr"""107 self.pos(p, head_transfer=True, tail_transfer=False)108 p[0].head = p[1].head109 p[2].head = p[1].tail + p[2].head110 def post_unary(self, p):111 """expr OP"""112 self.pos(p, head_transfer=False, tail_transfer=True)113 p[1].tail += p[2].head114 p[0].tail = p[2].tail115 def paren(self, p):116 """( expr )"""117 self.pos(p, head_transfer=True, tail_transfer=True)118 # p[0] is global element (Group or FieldGroup)119 # p[2] is content120 # p[1] is left parenthesis121 p[0].head = p[1].head122 p[2].head = p[1].tail + p[2].head123 # 
p[3] is right parenthesis124 p[2].tail += p[3].head125 p[0].tail = p[3].tail126 def range(self, p):127 """[ expr TO expr ]"""128 self.pos(p, head_transfer=True, tail_transfer=True)129 # p[0] is global element (Range)130 # p[2] is lower bound131 p[0].head = p[1].head132 p[2].head = p[1].tail + p[2].head133 # p[3] is TO134 # p[4] is upper bound135 p[2].tail += p[3].head136 p[4].head = p[3].tail + p[4].head137 # p[5] is upper braket138 p[4].tail += p[5].head139 p[0].tail = p[5].tail140 def search_field(self, p):141 """name: expr"""142 self.pos(p, head_transfer=True, tail_transfer=False)143 # p[0] is global element (SearchField)144 # p[1] is search field name145 # p[2] is COLUMN146 p[0].head = p[1].head147 if p[1].tail or p[2].head:148 pass # FIXME: add warning, or handle space between point and name in SearchField ?149 # p[3] is the expression150 p[3].head = p[2].tail + p[3].head151head_tail = HeadTailManager()152"""singleton of HeadTailManager...
LambdaTest’s Playwright tutorial gives you a broad overview of the Playwright automation framework, its unique features, and its use cases, with examples that deepen your understanding of Playwright testing. The tutorial offers A-to-Z guidance, from installing the Playwright framework to best practices and advanced concepts.
Get 100 minutes of automation testing FREE!