Best Python code snippet using pandera_python
debugging_primitives_test.py
Source: debugging_primitives_test.py
...36except ModuleNotFoundError:37 rich = None38config.parse_flags_with_absl()39debug_print = debugging.debug_print40def _format_multiline(text):41 return textwrap.dedent(text).lstrip()42prev_xla_flags = None43def setUpModule():44 global prev_xla_flags45 # This will control the CPU devices. On TPU we always have 2 devices46 prev_xla_flags = jtu.set_host_platform_device_count(2)47# Reset to previous configuration in case other test modules will be run.48def tearDownModule():49 prev_xla_flags()50# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum51# version is >= 0.3.1552disabled_backends = []53if jaxlib.version < (0, 3, 15):54 disabled_backends.append("tpu")55class DummyDevice:56 def __init__(self, platform, id):57 self.platform = platform58 self.id = id59class DebugPrintTest(jtu.JaxTestCase):60 def tearDown(self):61 super().tearDown()62 dispatch.runtime_tokens.clear()63 @jtu.skip_on_devices(*disabled_backends)64 def test_simple_debug_print_works_in_eager_mode(self):65 def f(x):66 debug_print('x: {}', x)67 with jtu.capture_stdout() as output:68 f(2)69 jax.effects_barrier()70 self.assertEqual(output(), "x: 2\n")71 @jtu.skip_on_devices(*disabled_backends)72 def test_debug_print_works_with_named_format_strings(self):73 def f(x):74 debug_print('x: {x}', x=x)75 with jtu.capture_stdout() as output:76 f(2)77 jax.effects_barrier()78 self.assertEqual(output(), "x: 2\n")79 @jtu.skip_on_devices(*disabled_backends)80 def test_multiple_debug_prints_should_print_multiple_values(self):81 def f(x):82 debug_print('x: {x}', x=x)83 debug_print('y: {y}', y=x + 1)84 with jtu.capture_stdout() as output:85 f(2)86 jax.effects_barrier()87 self.assertEqual(output(), "x: 2\ny: 3\n")88 @jtu.skip_on_devices(*disabled_backends)89 def test_can_stage_out_debug_print(self):90 @jax.jit91 def f(x):92 debug_print('x: {x}', x=x)93 with jtu.capture_stdout() as output:94 f(2)95 jax.effects_barrier()96 self.assertEqual(output(), "x: 2\n")97 @jtu.skip_on_devices(*disabled_backends)98 
def test_can_stage_out_debug_print_with_donate_argnums(self):99 if jax.default_backend() not in {"gpu", "tpu"}:100 raise unittest.SkipTest("Donate argnums not supported.")101 def f(x, y):102 debug_print('x: {x}', x=x)103 return x + y104 f = jax.jit(f, donate_argnums=0)105 with jtu.capture_stdout() as output:106 f(2, 3)107 jax.effects_barrier()108 self.assertEqual(output(), "x: 2\n")109 @jtu.skip_on_devices(*disabled_backends)110 def test_can_stage_out_ordered_print(self):111 @jax.jit112 def f(x):113 debug_print('x: {x}', x=x, ordered=True)114 with jtu.capture_stdout() as output:115 f(2)116 jax.effects_barrier()117 self.assertEqual(output(), "x: 2\n")118 @jtu.skip_on_devices(*disabled_backends)119 def test_can_stage_out_ordered_print_with_donate_argnums(self):120 if jax.default_backend() not in {"gpu", "tpu"}:121 raise unittest.SkipTest("Donate argnums not supported.")122 def f(x, y):123 debug_print('x: {x}', x=x, ordered=True)124 return x + y125 f = jax.jit(f, donate_argnums=0)126 with jtu.capture_stdout() as output:127 f(2, 3)128 jax.effects_barrier()129 self.assertEqual(output(), "x: 2\n")130 @jtu.skip_on_devices(*disabled_backends)131 def test_can_stage_out_prints_with_donate_argnums(self):132 if jax.default_backend() not in {"gpu", "tpu"}:133 raise unittest.SkipTest("Donate argnums not supported.")134 def f(x, y):135 debug_print('x: {x}', x=x, ordered=True)136 debug_print('x: {x}', x=x)137 return x + y138 f = jax.jit(f, donate_argnums=0)139 with jtu.capture_stdout() as output:140 f(2, 3)141 jax.effects_barrier()142 self.assertEqual(output(), "x: 2\nx: 2\n")143 @jtu.skip_on_devices(*disabled_backends)144 def test_can_double_stage_out_ordered_print(self):145 @jax.jit146 @jax.jit147 def f(x):148 debug_print('x: {x}', x=x, ordered=True)149 with jtu.capture_stdout() as output:150 f(2)151 jax.effects_barrier()152 self.assertEqual(output(), "x: 2\n")153 @jtu.skip_on_devices(*disabled_backends)154 def test_can_stage_out_ordered_print_with_pytree(self):155 @jax.jit156 
def f(x):157 struct = dict(foo=x)158 debug_print('x: {}', struct, ordered=True)159 with jtu.capture_stdout() as output:160 f(np.array(2, np.int32))161 jax.effects_barrier()162 self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")163class DebugPrintTransformationTest(jtu.JaxTestCase):164 def test_debug_print_batching(self):165 @jax.vmap166 def f(x):167 debug_print('hello: {}', x)168 with jtu.capture_stdout() as output:169 f(jnp.arange(2))170 jax.effects_barrier()171 self.assertEqual(output(), "hello: 0\nhello: 1\n")172 def test_debug_print_batching_with_diff_axes(self):173 @functools.partial(jax.vmap, in_axes=(0, 1))174 def f(x, y):175 debug_print('hello: {} {}', x, y)176 with jtu.capture_stdout() as output:177 f(jnp.arange(2), jnp.arange(2)[None])178 jax.effects_barrier()179 self.assertEqual(output(), "hello: 0 [0]\nhello: 1 [1]\n")180 def tested_debug_print_with_nested_vmap(self):181 def f(x):182 debug_print('hello: {}', x)183 # Call with184 # [[0, 1],185 # [2, 3],186 # [4, 5]]187 with jtu.capture_stdout() as output:188 # Should print over 0-axis then 1-axis189 jax.vmap(jax.vmap(f))(jnp.arange(6).reshape((3, 2)))190 jax.effects_barrier()191 self.assertEqual(192 output(),193 "hello: 0\nhello: 2\nhello: 4\nhello: 1\nhello: 3\nhello: 5\n")194 with jtu.capture_stdout() as output:195 # Should print over 1-axis then 0-axis196 jax.vmap(jax.vmap(f, in_axes=0), in_axes=1)(jnp.arange(6).reshape((3, 2)))197 jax.effects_barrier()198 self.assertEqual(199 output(),200 "hello: 0\nhello: 1\nhello: 2\nhello: 3\nhello: 4\nhello: 5\n")201 def test_debug_print_jvp_rule(self):202 def f(x):203 debug_print('x: {}', x)204 with jtu.capture_stdout() as output:205 jax.jvp(f, (1.,), (1.,))206 jax.effects_barrier()207 self.assertEqual(output(), "x: 1.0\n")208 def test_debug_print_vjp_rule(self):209 def f(x):210 debug_print('x: {}', x)211 with jtu.capture_stdout() as output:212 jax.vjp(f, 1.)213 jax.effects_barrier()214 self.assertEqual(output(), "x: 1.0\n")215 def 
test_debug_print_in_custom_jvp(self):216 @jax.custom_jvp217 def print_tangent(x):218 return x219 @print_tangent.defjvp220 def _(primals, tangents):221 (x,), (t,) = primals, tangents222 debug_print("x_tangent: {}", t)223 return x, t224 def f(x):225 x = jnp.sin(x)226 x = print_tangent(x)227 return x228 with jtu.capture_stdout() as output:229 x = jnp.array(1., jnp.float32)230 jax.jvp(f, (x,), (x,))231 jax.effects_barrier()232 expected = jnp.cos(jnp.array(1., jnp.float32))233 self.assertEqual(output(), f"x_tangent: {expected}\n")234 @unittest.skip("doesn't work yet!") # TODO(mattjj,sharadmv)235 def test_debug_print_in_custom_jvp_linearize(self):236 @jax.custom_jvp237 def print_tangent(x):238 return x239 @print_tangent.defjvp240 def _(primals, tangents):241 (x,), (t,) = primals, tangents242 debug_print("x_tangent: {}", t)243 return x, t244 def f(x):245 x = jnp.sin(x)246 x = print_tangent(x)247 return x248 with jtu.capture_stdout() as output:249 x = jnp.array(1., jnp.float32)250 y, f_lin = jax.linearize(f, x)251 jax.effects_barrier()252 self.assertEqual(output(), "")253 with jtu.capture_stdout() as output:254 _ = f_lin(x)255 jax.effects_barrier()256 expected = jnp.cos(jnp.array(1., jnp.float32))257 self.assertEqual(output(), f"x_tangent: {expected}\n")258 def test_debug_print_grad_with_custom_vjp_rule(self):259 @jax.custom_vjp260 def print_grad(x):261 return x262 def print_grad_fwd(x):263 return x, None264 def print_grad_bwd(_, x_grad):265 debug_print("x_grad: {}", x_grad)266 return (x_grad,)267 print_grad.defvjp(print_grad_fwd, print_grad_bwd)268 def f(x):269 debug_print("x: {}", x)270 x = print_grad(x)271 return jnp.sin(x)272 with jtu.capture_stdout() as output:273 jax.grad(f)(jnp.array(1., jnp.float32))274 jax.effects_barrier()275 expected = jnp.cos(jnp.array(1., jnp.float32))276 self.assertEqual(output(), f"x: 1.0\nx_grad: {expected}\n")277 def test_debug_print_transpose_rule(self):278 def f(x):279 debug_print('should never be called: {}', x)280 return x281 with 
jtu.capture_stdout() as output:282 jax.linear_transpose(f, 1.)(1.)283 jax.effects_barrier()284 # `debug_print` should be dropped by `partial_eval` because of no285 # output data-dependence.286 self.assertEqual(output(), "")287 @parameterized.named_parameters(jtu.cases_from_list(288 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)289 for ordered in [False, True]))290 def test_remat_of_debug_print(self, ordered):291 def f_(x):292 y = ad_checkpoint.checkpoint_name(x + 1., "y")293 z = ad_checkpoint.checkpoint_name(y * 2., "z")294 debug_print('y: {}, z: {}', y, z, ordered=ordered)295 return ad_checkpoint.checkpoint_name(jnp.exp(z), "w")296 # Policy that saves everything so the debug callback will be saved297 f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.everything_saveable)298 with jtu.capture_stdout() as output:299 jax.grad(f)(2.)300 jax.effects_barrier()301 # We expect the print to happen once since it gets saved and isn't302 # rematerialized.303 self.assertEqual(output(), "y: 3.0, z: 6.0\n")304 # Policy that saves nothing so everything gets rematerialized, including the305 # debug callback306 f = ad_checkpoint.checkpoint(f_, policy=ad_checkpoint.nothing_saveable)307 with jtu.capture_stdout() as output:308 jax.grad(f)(2.)309 jax.effects_barrier()310 # We expect the print to happen twice since it is rematerialized.311 self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)312 # Policy that does not save `z` so we will need to rematerialize the print313 f = ad_checkpoint.checkpoint(314 f_, policy=ad_checkpoint.save_any_names_but_these("z"))315 with jtu.capture_stdout() as output:316 jax.grad(f)(2.)317 jax.effects_barrier()318 # We expect the print to happen twice since it is rematerialized.319 self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)320 def save_everything_but_these_names(*names_not_to_save):321 names_not_to_save = frozenset(names_not_to_save)322 def policy(prim, *_, **params):323 if prim is ad_checkpoint.name_p:324 return params['name'] 
not in names_not_to_save325 return True # Save everything else326 return policy327 # Policy that saves everything but `y`328 f = ad_checkpoint.checkpoint(329 f_, policy=save_everything_but_these_names("y"))330 with jtu.capture_stdout() as output:331 jax.grad(f)(2.)332 jax.effects_barrier()333 # We expect the print to happen once because `y` is not rematerialized and334 # we won't do extra materialization.335 self.assertEqual(output(), "y: 3.0, z: 6.0\n")336 # Policy that saves everything but `y` and `z`337 f = ad_checkpoint.checkpoint(338 f_, policy=save_everything_but_these_names("y", "z"))339 with jtu.capture_stdout() as output:340 jax.grad(f)(2.)341 jax.effects_barrier()342 # We expect the print to happen twice because both `y` and `z` have been343 # rematerialized and we don't have to do any extra rematerialization to344 # print.345 self.assertEqual(output(), "y: 3.0, z: 6.0\n" * 2)346 @jtu.skip_on_devices(*disabled_backends)347 def test_debug_print_in_staged_out_custom_jvp(self):348 @jax.jit349 def f(x):350 @jax.custom_jvp351 def g(x):352 debug_print("hello: {x}", x=x)353 return x354 def g_jvp(primals, tangents):355 (x,), (t,) = primals, tangents356 debug_print("goodbye: {x} {t}", x=x, t=t)357 return x, t358 g.defjvp(g_jvp)359 return g(x)360 with jtu.capture_stdout() as output:361 f(2.)362 jax.effects_barrier()363 self.assertEqual(output(), "hello: 2.0\n")364 with jtu.capture_stdout() as output:365 jax.jvp(f, (2.,), (3.,))366 jax.effects_barrier()367 self.assertEqual(output(), "goodbye: 2.0 3.0\n")368 @jtu.skip_on_devices(*disabled_backends)369 def test_debug_print_in_staged_out_custom_vjp(self):370 @jax.jit371 def f(x):372 @jax.custom_vjp373 def g(x):374 debug_print("hello: {x}", x=x)375 return x376 def g_fwd(x):377 debug_print("hello fwd: {x}", x=x)378 return x, x379 def g_bwd(x, g):380 debug_print("hello bwd: {x} {g}", x=x, g=g)381 return (g,)382 g.defvjp(fwd=g_fwd, bwd=g_bwd)383 return g(x)384 with jtu.capture_stdout() as output:385 f(2.)386 
jax.effects_barrier()387 self.assertEqual(output(), "hello: 2.0\n")388 with jtu.capture_stdout() as output:389 _, f_vjp = jax.vjp(f, 2.)390 jax.effects_barrier()391 self.assertEqual(output(), "hello fwd: 2.0\n")392 with jtu.capture_stdout() as output:393 f_vjp(3.0)394 jax.effects_barrier()395 self.assertEqual(output(), "hello bwd: 2.0 3.0\n")396class DebugPrintControlFlowTest(jtu.JaxTestCase):397 def _assertLinesEqual(self, text1, text2):398 def _count(lines):399 return collections.Counter(lines)400 self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n")))401 @parameterized.named_parameters(jtu.cases_from_list(402 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)403 for ordered in [False, True]))404 @jtu.skip_on_devices(*disabled_backends)405 def test_can_print_inside_scan(self, ordered):406 def f(xs):407 def _body(carry, x):408 debug_print("carry: {carry}, x: {x}", carry=carry, x=x, ordered=ordered)409 return carry + 1, x + 1410 return lax.scan(_body, 2, xs)411 with jtu.capture_stdout() as output:412 f(jnp.arange(2))413 jax.effects_barrier()414 self.assertEqual(415 output(),416 _format_multiline("""417 carry: 2, x: 0418 carry: 3, x: 1419 """))420 @parameterized.named_parameters(jtu.cases_from_list(421 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)422 for ordered in [False, True]))423 @jtu.skip_on_devices(*disabled_backends)424 def test_can_print_inside_for_loop(self, ordered):425 def f(x):426 def _body(i, x):427 debug_print("i: {i}", i=i, ordered=ordered)428 debug_print("x: {x}", x=x, ordered=ordered)429 return x + 1430 return lax.fori_loop(0, 5, _body, x)431 with jtu.capture_stdout() as output:432 f(2)433 jax.effects_barrier()434 expected = _format_multiline("""435 i: 0436 x: 2437 i: 1438 x: 3439 i: 2440 x: 4441 i: 3442 x: 5443 i: 4444 x: 6445 """)446 if ordered:447 self.assertEqual(output(), expected)448 else:449 self._assertLinesEqual(output(), expected)450 
@parameterized.named_parameters(jtu.cases_from_list(451 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)452 for ordered in [False, True]))453 @jtu.skip_on_devices(*disabled_backends)454 def test_can_print_inside_while_loop_body(self, ordered):455 def f(x):456 def _cond(x):457 return x < 10458 def _body(x):459 debug_print("x: {x}", x=x, ordered=ordered)460 return x + 1461 return lax.while_loop(_cond, _body, x)462 with jtu.capture_stdout() as output:463 f(5)464 jax.effects_barrier()465 self.assertEqual(output(), _format_multiline("""466 x: 5467 x: 6468 x: 7469 x: 8470 x: 9471 """))472 @parameterized.named_parameters(jtu.cases_from_list(473 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)474 for ordered in [False, True]))475 @jtu.skip_on_devices(*disabled_backends)476 def test_can_print_inside_while_loop_cond(self, ordered):477 def f(x):478 def _cond(x):479 debug_print("x: {x}", x=x, ordered=ordered)480 return x < 10481 def _body(x):482 return x + 1483 return lax.while_loop(_cond, _body, x)484 with jtu.capture_stdout() as output:485 f(5)486 jax.effects_barrier()487 self.assertEqual(output(), _format_multiline("""488 x: 5489 x: 6490 x: 7491 x: 8492 x: 9493 x: 10494 """))495 with jtu.capture_stdout() as output:496 f(10)497 jax.effects_barrier()498 # Should run the cond once499 self.assertEqual(output(), _format_multiline("""500 x: 10501 """))502 @parameterized.named_parameters(jtu.cases_from_list(503 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)504 for ordered in [False, True]))505 @jtu.skip_on_devices(*disabled_backends)506 def test_can_print_in_batched_while_cond(self, ordered):507 def f(x):508 def _cond(x):509 debug_print("x: {x}", x=x, ordered=ordered)510 return x < 5511 def _body(x):512 return x + 1513 return lax.while_loop(_cond, _body, x)514 with jtu.capture_stdout() as output:515 jax.vmap(f)(jnp.arange(2))516 jax.effects_barrier()517 if ordered:518 expected = _format_multiline("""519 x: 0520 x: 1521 x: 
1522 x: 2523 x: 2524 x: 3525 x: 3526 x: 4527 x: 4528 x: 5529 x: 5530 x: 6531 """)532 self.assertEqual(output(), expected)533 else:534 # When the print is unordered, the `cond` is called an additional time535 # after the `_body` runs, so we get more prints.536 expected = _format_multiline("""537 x: 0538 x: 1539 x: 0540 x: 1541 x: 1542 x: 2543 x: 1544 x: 2545 x: 2546 x: 3547 x: 2548 x: 3549 x: 3550 x: 4551 x: 3552 x: 4553 x: 4554 x: 5555 x: 4556 x: 5557 x: 5558 x: 5559 """)560 self._assertLinesEqual(output(), expected)561 @parameterized.named_parameters(jtu.cases_from_list(562 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)563 for ordered in [False, True]))564 @jtu.skip_on_devices(*disabled_backends)565 def test_can_print_inside_cond(self, ordered):566 def f(x):567 def true_fun(x):568 debug_print("true: {}", x, ordered=ordered)569 return x570 def false_fun(x):571 debug_print("false: {}", x, ordered=ordered)572 return x573 return lax.cond(x < 5, true_fun, false_fun, x)574 with jtu.capture_stdout() as output:575 f(5)576 jax.effects_barrier()577 self.assertEqual(output(), _format_multiline("""578 false: 5579 """))580 with jtu.capture_stdout() as output:581 f(4)582 jax.effects_barrier()583 self.assertEqual(output(), _format_multiline("""584 true: 4585 """))586 @parameterized.named_parameters(jtu.cases_from_list(587 dict(testcase_name="_ordered" if ordered else "", ordered=ordered)588 for ordered in [False, True]))589 @jtu.skip_on_devices(*disabled_backends)590 def test_can_print_inside_switch(self, ordered):591 def f(x):592 def b1(x):593 debug_print("b1: {}", x, ordered=ordered)594 return x595 def b2(x):596 debug_print("b2: {}", x, ordered=ordered)597 return x598 def b3(x):599 debug_print("b3: {}", x, ordered=ordered)600 return x601 return lax.switch(x, (b1, b2, b3), x)602 with jtu.capture_stdout() as output:603 f(0)604 jax.effects_barrier()605 self.assertEqual(output(), _format_multiline("""606 b1: 0607 """))608 with jtu.capture_stdout() as output:609 
f(1)610 jax.effects_barrier()611 self.assertEqual(output(), _format_multiline("""612 b2: 1613 """))614 with jtu.capture_stdout() as output:615 f(2)616 jax.effects_barrier()617 self.assertEqual(output(), _format_multiline("""618 b3: 2619 """))620class DebugPrintParallelTest(jtu.JaxTestCase):621 def _assertLinesEqual(self, text1, text2):622 def _count(lines):623 return collections.Counter(lines)624 self.assertDictEqual(_count(text1.split("\n")), _count(text2.split("\n")))625 @jtu.skip_on_devices(*disabled_backends)626 def test_ordered_print_not_supported_in_pmap(self):627 @jax.pmap628 def f(x):629 debug_print("{}", x, ordered=True)630 with self.assertRaisesRegex(631 ValueError, "Ordered effects not supported in `pmap`."):632 f(jnp.arange(jax.local_device_count()))633 @jtu.skip_on_devices(*disabled_backends)634 def test_unordered_print_works_in_pmap(self):635 if jax.device_count() < 2:636 raise unittest.SkipTest("Test requires >= 2 devices.")637 @jax.pmap638 def f(x):639 debug_print("hello: {}", x, ordered=False)640 with jtu.capture_stdout() as output:641 f(jnp.arange(jax.local_device_count()))642 jax.effects_barrier()643 lines = [f"hello: {i}\n" for i in range(jax.local_device_count())]644 self._assertLinesEqual(output(), "".join(lines))645 @jax.pmap646 def f2(x):647 debug_print('hello: {}', x)648 debug_print('hello: {}', x + 2)649 with jtu.capture_stdout() as output:650 f2(jnp.arange(2))651 jax.effects_barrier()652 self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")653 @jtu.skip_on_devices(*disabled_backends)654 def test_unordered_print_with_pjit(self):655 if jax.default_backend() in {"cpu", "gpu"} and jaxlib.version < (0, 3, 16):656 raise unittest.SkipTest("`pjit` of callback not supported.")657 def f(x):658 debug_print("{}", x, ordered=False)659 return x660 mesh = maps.Mesh(np.array(jax.devices()), ['dev'])661 if config.jax_array:662 spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))663 out_spec = 
sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())664 else:665 spec = pjit.PartitionSpec('dev')666 out_spec = pjit.PartitionSpec()667 f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)668 with mesh:669 with jtu.capture_stdout() as output:670 f(np.arange(8, dtype=jnp.int32))671 jax.effects_barrier()672 self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n")673 def f2(x):674 y = x.dot(x)675 debug_print("{}", y, ordered=False)676 return y677 f2 = pjit.pjit(f2, in_axis_resources=spec, out_axis_resources=out_spec)678 with maps.Mesh(np.array(jax.devices()), ['dev']):679 with jtu.capture_stdout() as output:680 f2(np.arange(8, dtype=jnp.int32))681 jax.effects_barrier()682 self.assertEqual(output(), "140\n")683 @jtu.skip_on_devices(*disabled_backends)684 def test_unordered_print_of_pjit_of_while(self):685 if (jax.default_backend() in {"cpu", "gpu"}686 and jaxlib.xla_extension_version < 81):687 raise unittest.SkipTest("`pjit` of callback not supported.")688 def f(x):689 def cond(carry):690 i, *_ = carry691 return i < 5692 def body(carry):693 i, x = carry694 debug_print("{}", x, ordered=False)695 x = x + 1696 return (i + 1, x)697 return lax.while_loop(cond, body, (0, x))[1]698 mesh = maps.Mesh(np.array(jax.devices()), ['dev'])699 if config.jax_array:700 spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))701 else:702 spec = pjit.PartitionSpec('dev')703 f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)704 with mesh:705 with jtu.capture_stdout() as output:706 f(np.arange(8, dtype=jnp.int32))707 jax.effects_barrier()708 self.assertEqual(output(),709 "[0 1 2 3 4 5 6 7]\n"710 "[1 2 3 4 5 6 7 8]\n"711 "[2 3 4 5 6 7 8 9]\n"712 "[ 3 4 5 6 7 8 9 10]\n"713 "[ 4 5 6 7 8 9 10 11]\n")714 @jtu.skip_on_devices(*disabled_backends)715 def test_unordered_print_of_pjit_of_xmap(self):716 if (jax.default_backend() in {"cpu", "gpu"}717 and jaxlib.xla_extension_version < 81):718 raise unittest.SkipTest("`pjit` of callback not supported.")719 def 
f(x):720 def foo(x):721 idx = lax.axis_index('foo')722 debug_print("{idx}: {x}", idx=idx, x=x)723 return jnp.mean(x, axis=['foo'])724 out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)725 debug_print("Out: {}", out)726 return out727 mesh = maps.Mesh(np.array(jax.devices()), ['dev'])728 if config.jax_array:729 in_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))730 out_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())731 else:732 in_spec = pjit.PartitionSpec('dev')733 out_spec = pjit.PartitionSpec()734 f = pjit.pjit(f, in_axis_resources=in_spec, out_axis_resources=out_spec)735 with mesh:736 with jtu.capture_stdout() as output:737 f(jnp.arange(8, dtype=jnp.int32) * 2)738 lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",739 "7: 14", "Out: 7.0", ""]740 jax.effects_barrier()741 self._assertLinesEqual(output(), "\n".join(lines))742 @jtu.skip_on_devices(*disabled_backends)743 def test_unordered_print_with_xmap(self):744 def f(x):745 debug_print("{}", x, ordered=False)746 f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',747 axis_resources={'a': 'dev'})748 with maps.Mesh(np.array(jax.devices()), ['dev']):749 with jtu.capture_stdout() as output:750 f(np.arange(40))751 jax.effects_barrier()752 lines = [f"{i}\n" for i in range(40)]753 self._assertLinesEqual(output(), "".join(lines))754 @jtu.skip_on_devices(*disabled_backends)755 def test_unordered_print_works_in_pmap_of_while(self):756 if jax.device_count() < 2:757 raise unittest.SkipTest("Test requires >= 2 devices.")758 @jax.pmap759 def f(x):760 def cond(x):761 return x < 3762 def body(x):763 debug_print("hello: {}", x, ordered=False)764 return x + 1765 return lax.while_loop(cond, body, x)766 with jtu.capture_stdout() as output:767 f(jnp.arange(2))768 jax.effects_barrier()769 self._assertLinesEqual(770 output(), "hello: 0\nhello: 1\nhello: 2\n"771 "hello: 1\nhello: 2\n")772 @jtu.skip_on_devices(*disabled_backends)773 def 
test_incorrectly_formatted_string(self):774 @jax.jit775 def f(x):776 debug_print("hello: {x}", x)777 return x778 with self.assertRaises(KeyError):779 f(jnp.arange(2))780 jax.effects_barrier()781 @jax.jit782 def f(x):783 debug_print("hello: {}", x=x)784 return x785 with self.assertRaises(IndexError):786 f(jnp.arange(2))787 jax.effects_barrier()788 @jtu.skip_on_devices(*disabled_backends)789 def test_format_string_errors_with_unused_args(self):790 @jax.jit791 def f(x):792 debug_print("hello: {x}", x=x, y=x)793 return x794 with self.assertRaisesRegex(ValueError, "Unused keyword arguments"):795 f(jnp.arange(2))796 jax.effects_barrier()797 @jax.jit798 def g(x):799 debug_print("hello", x)800 return x801 with self.assertRaisesRegex(ValueError, "Unused positional arguments"):802 g(jnp.arange(2))803 jax.effects_barrier()804 @jtu.skip_on_devices(*disabled_backends)805 def test_accidental_fstring(self):806 @jax.jit807 def f(x):808 debug_print(f"hello: {x}", x=x)809 return x810 with self.assertRaisesRegex(ValueError, "You may be passing an f-string"):811 f(jnp.arange(2))812 jax.effects_barrier()813class VisualizeShardingTest(jtu.JaxTestCase):814 def _create_devices(self, shape):815 num_devices = np.prod(shape)816 devices = [DummyDevice("CPU", i) for i in range(num_devices)]817 return np.array(devices).reshape(shape)818 def test_trivial_sharding(self):819 mesh = maps.Mesh(self._create_devices(1), ['x'])820 pspec = pjit.PartitionSpec('x')821 sd = sharding.MeshPspecSharding(mesh, pspec)822 shape = (5,)823 with jtu.capture_stdout() as output:824 debugging.visualize_sharding(shape, sd)825 self.assertEqual(output(), _format_multiline("""826 âââââââââ827 â CPU 0 â828 âââââââââ829 """))830 def test_trivial_sharding_with_scale(self):831 mesh = maps.Mesh(self._create_devices(1), ['x'])832 pspec = pjit.PartitionSpec('x')833 sd = sharding.MeshPspecSharding(mesh, pspec)834 shape = (5,)835 with jtu.capture_stdout() as output:836 debugging.visualize_sharding(shape, sd, scale=8.)837 
self.assertEqual(output(), _format_multiline("""838 ââââââââââââââââ839 â CPU 0 â840 ââââââââââââââââ841 """))842 def test_full_sharding(self):843 mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])844 pspec = pjit.PartitionSpec('x', 'y')845 sd = sharding.MeshPspecSharding(mesh, pspec)846 shape = (8, 8)847 with jtu.capture_stdout() as output:848 debugging.visualize_sharding(shape, sd)849 expected = _format_multiline("""850 âââââââââ¬ââââââââ¬ââââââââ¬ââââââââ851 â CPU 0 â CPU 1 â CPU 2 â CPU 3 â852 âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤853 â CPU 4 â CPU 5 â CPU 6 â CPU 7 â854 âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤855 â CPU 8 â CPU 9 âCPU 10 âCPU 11 â856 âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤857 âCPU 12 âCPU 13 âCPU 14 âCPU 15 â858 âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤859 âCPU 16 âCPU 17 âCPU 18 âCPU 19 â860 âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤861 âCPU 20 âCPU 21 âCPU 22 âCPU 23 â862 âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤863 âCPU 24 âCPU 25 âCPU 26 âCPU 27 â864 âââââââââ¼ââââââââ¼ââââââââ¼ââââââââ¤865 âCPU 28 âCPU 29 âCPU 30 âCPU 31 â866 âââââââââ´ââââââââ´ââââââââ´ââââââââ867 """)868 self.assertEqual(output(), expected)869 def test_sharding_with_replication(self):870 shape = (8, 8)871 mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])872 pspec = pjit.PartitionSpec('x', None)873 sd = sharding.MeshPspecSharding(mesh, pspec)874 with jtu.capture_stdout() as output:875 debugging.visualize_sharding(shape, sd)876 expected = _format_multiline("""877 âââââââââââââââââââââââââ878 â CPU 0,1,2,3 â879 âââââââââââââââââââââââââ¤880 â CPU 4,5,6,7 â881 âââââââââââââââââââââââââ¤882 â CPU 8,9,10,11 â883 âââââââââââââââââââââââââ¤884 â CPU 12,13,14,15 â885 âââââââââââââââââââââââââ¤886 â CPU 16,17,18,19 â887 âââââââââââââââââââââââââ¤888 â CPU 20,21,22,23 â889 âââââââââââââââââââââââââ¤890 â CPU 24,25,26,27 â891 âââââââââââââââââââââââââ¤892 â CPU 28,29,30,31 â893 âââââââââââââââââââââââââ894 """)895 self.assertEqual(output(), expected)896 mesh = 
maps.Mesh(self._create_devices((4, 2)), ['x', 'y'])897 pspec = pjit.PartitionSpec(None, 'y')898 sd = sharding.MeshPspecSharding(mesh, pspec)899 with jtu.capture_stdout() as output:900 debugging.visualize_sharding(shape, sd)901 expected = _format_multiline("""902 âââââââââââââ¬ââââââââââââ903 â â â904 â â â905 â â â906 â â â907 âCPU 0,2,4,6âCPU 1,3,5,7â908 â â â909 â â â910 â â â911 â â â912 âââââââââââââ´ââââââââââââ913 """)914 self.assertEqual(output(), expected)915 def test_visualize_wide_array(self):916 shape = (128, 10000)917 mesh = maps.Mesh(self._create_devices((8, 4)), ['x', 'y'])918 pspec = pjit.PartitionSpec('x', None)919 sd = sharding.MeshPspecSharding(mesh, pspec)920 with jtu.capture_stdout() as output:921 debugging.visualize_sharding(shape, sd)922 expected = _format_multiline("""923 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ924 â CPU 0,1,2,3 â925 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤926 â CPU 4,5,6,7 â927 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤928 â CPU 8,9,10,11 â929 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤930 â CPU 12,13,14,15 â931 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤932 â CPU 16,17,18,19 â933 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤934 â CPU 20,21,22,23 â935 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤936 â CPU 24,25,26,27 â937 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ¤938 â CPU 28,29,30,31 â939 ââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââââ940 """)941 self.assertEqual(output(), expected)942 def test_visualize_pmap_sharding(self):943 ss = pxla.ShardingSpec(944 sharding=(pxla.Unstacked(8),),945 mesh_mapping=(pxla.ShardedAxis(0),))946 sd = sharding.PmapSharding(self._create_devices(8), ss)947 shape 
= (8,)948 with jtu.capture_stdout() as output:949 debugging.visualize_sharding(shape, sd)950 expected = _format_multiline("""951 âââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ¬ââââââââ952 â CPU 0 â CPU 1 â CPU 2 â CPU 3 â CPU 4 â CPU 5 â CPU 6 â CPU 7 â953 âââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ´ââââââââ954 """)955 self.assertEqual(output(), expected)956 ss = pxla.ShardingSpec(957 sharding=(pxla.Unstacked(8), pxla.NoSharding()),958 mesh_mapping=(pxla.ShardedAxis(0),))959 sd = sharding.PmapSharding(self._create_devices(8), ss)960 shape = (8, 2)961 with jtu.capture_stdout() as output:962 debugging.visualize_sharding(shape, sd)963 expected = _format_multiline("""964 âââââââââ965 â CPU 0 â966 âââââââââ¤967 â CPU 1 â968 âââââââââ¤969 â CPU 2 â970 âââââââââ¤971 â CPU 3 â972 âââââââââ¤973 â CPU 4 â974 âââââââââ¤975 â CPU 5 â976 âââââââââ¤977 â CPU 6 â...
debugger_test.py
Source: debugger_test.py
# NOTE(review): the listing below is a web-scraped excerpt of debugger_test.py
# (see the "Source:debugger_test.py" caption). The scraper fused the original
# source line numbers into the code (e.g. `return y60`, `f(2.)71`), so this
# text is not valid Python as-is; presumably it should be recovered from the
# upstream file rather than edited in place.
#
# Contents, in order:
#   * tail of make_fake_stdin_stdout (its header is elided by the leading
#     "..."): writes each command plus "\n" to fake_stdin, seeks it back to
#     position 0, and returns (fake_stdin, io.StringIO()) -- a scripted stdin
#     and a buffer that captures the debugger's stdout.
#   * _format_multiline: textwrap.dedent(...).lstrip() for expected-output
#     literals.
#   * setUpModule / tearDownModule: call
#     jtu.set_host_platform_device_count(2) and later invoke the callable it
#     returned to restore the previous configuration.
#   * disabled_backends: skips TPU when jaxlib.version < (0, 3, 15).
#   * CliDebuggerTest: each test scripts debugger commands ("c", "p x", "l",
#     "bt", "u", "d", ...) through the fake stdin, calls a function containing
#     debugger.breakpoint(..., backend="cli") (eagerly, or under jax.jit,
#     jax.vmap, jax.pmap or pjit), waits with jax.effects_barrier(), and
#     compares the captured stdout using assertEqual, or assertRegex where the
#     output is only partially predictable (file paths, line numbers, values).
#     test_debugger_eof feeds no commands and expects SystemExit;
#     test_debugger_works_with_pmap needs >= 2 local devices;
#     test_debugger_works_with_pjit only runs on the TPU backend.
#   * a trailing `if __name__ == '__main__':` whose body is elided by "...".
#
# NOTE(review): test_debugger_works_with_pjit and
# test_debugger_uses_local_before_global_scope print stdout and `expected`
# before asserting; these look like leftover debug prints and can be dropped.
# NOTE(review): test_debugger_can_print_value (eager) expects a
# "DeviceArray(...)" repr while the jit variants expect "array(...)"; the
# eager expectation presumably depends on the JAX version's array repr --
# confirm.
...33 for command in commands:34 fake_stdin.write(command + "\n")35 fake_stdin.seek(0)36 return fake_stdin, io.StringIO()37def _format_multiline(text):38 return textwrap.dedent(text).lstrip()39prev_xla_flags = None40def setUpModule():41 global prev_xla_flags42 # This will control the CPU devices. On TPU we always have 2 devices43 prev_xla_flags = jtu.set_host_platform_device_count(2)44# Reset to previous configuration in case other test modules will be run.45def tearDownModule():46 prev_xla_flags()47# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum48# version is >= 0.3.1549disabled_backends = []50if jaxlib.version < (0, 3, 15):51 disabled_backends.append("tpu")52class CliDebuggerTest(jtu.JaxTestCase):53 @jtu.skip_on_devices(*disabled_backends)54 def test_debugger_eof(self):55 stdin, stdout = make_fake_stdin_stdout([])56 def f(x):57 y = jnp.sin(x)58 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")59 return y60 with self.assertRaises(SystemExit):61 f(2.)62 jax.effects_barrier()63 @jtu.skip_on_devices(*disabled_backends)64 def test_debugger_can_continue(self):65 stdin, stdout = make_fake_stdin_stdout(["c"])66 def f(x):67 y = jnp.sin(x)68 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")69 return y70 f(2.)71 jax.effects_barrier()72 expected = _format_multiline(r"""73 Entering jdb:74 (jdb) """)75 self.assertEqual(stdout.getvalue(), expected)76 @jtu.skip_on_devices(*disabled_backends)77 def test_debugger_can_print_value(self):78 stdin, stdout = make_fake_stdin_stdout(["p x", "c"])79 def f(x):80 y = jnp.sin(x)81 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")82 return y83 expected = _format_multiline(r"""84 Entering jdb:85 (jdb) DeviceArray(2., dtype=float32)86 (jdb) """)87 f(jnp.array(2., jnp.float32))88 jax.effects_barrier()89 self.assertEqual(stdout.getvalue(), expected)90 @jtu.skip_on_devices(*disabled_backends)91 def test_debugger_can_print_value_in_jit(self):92 stdin, stdout = 
make_fake_stdin_stdout(["p x", "c"])93 @jax.jit94 def f(x):95 y = jnp.sin(x)96 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")97 return y98 expected = _format_multiline(r"""99 Entering jdb:100 (jdb) array(2., dtype=float32)101 (jdb) """)102 f(jnp.array(2., jnp.float32))103 jax.effects_barrier()104 self.assertEqual(stdout.getvalue(), expected)105 @jtu.skip_on_devices(*disabled_backends)106 def test_debugger_can_print_multiple_values(self):107 stdin, stdout = make_fake_stdin_stdout(["p x, y", "c"])108 @jax.jit109 def f(x):110 y = x + 1.111 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")112 return y113 expected = _format_multiline(r"""114 Entering jdb:115 (jdb) (array(2., dtype=float32), array(3., dtype=float32))116 (jdb) """)117 f(jnp.array(2., jnp.float32))118 jax.effects_barrier()119 self.assertEqual(stdout.getvalue(), expected)120 @jtu.skip_on_devices(*disabled_backends)121 def test_debugger_can_print_context(self):122 stdin, stdout = make_fake_stdin_stdout(["l", "c"])123 @jax.jit124 def f(x):125 y = jnp.sin(x)126 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")127 return y128 f(2.)129 jax.effects_barrier()130 expected = _format_multiline(r"""131 Entering jdb:132 \(jdb\) > .*debugger_test\.py\([0-9]+\)133 @jax\.jit134 def f\(x\):135 y = jnp\.sin\(x\)136 -> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)137 return y138 .*139 \(jdb\) """)140 self.assertRegex(stdout.getvalue(), expected)141 @jtu.skip_on_devices(*disabled_backends)142 def test_debugger_can_print_backtrace(self):143 stdin, stdout = make_fake_stdin_stdout(["bt", "c"])144 @jax.jit145 def f(x):146 y = jnp.sin(x)147 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")148 return y149 expected = _format_multiline(r"""150 Entering jdb:.*151 \(jdb\) Traceback:.*152 """)153 f(2.)154 jax.effects_barrier()155 self.assertRegex(stdout.getvalue(), expected)156 @jtu.skip_on_devices(*disabled_backends)157 def 
test_debugger_can_work_with_multiple_stack_frames(self):158 stdin, stdout = make_fake_stdin_stdout(["l", "u", "p x", "d", "c"])159 def f(x):160 y = jnp.sin(x)161 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")162 return y163 @jax.jit164 def g(x):165 y = f(x)166 return jnp.exp(y)167 expected = _format_multiline(r"""168 Entering jdb:169 \(jdb\) > .*debugger_test\.py\([0-9]+\)170 def f\(x\):171 y = jnp\.sin\(x\)172 -> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)173 return y174 .*175 \(jdb\) > .*debugger_test\.py\([0-9]+\).*176 @jax\.jit177 def g\(x\):178 -> y = f\(x\)179 return jnp\.exp\(y\)180 .*181 \(jdb\) array\(2\., dtype=float32\)182 \(jdb\) > .*debugger_test\.py\([0-9]+\)183 def f\(x\):184 y = jnp\.sin\(x\)185 -> debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)186 return y187 .*188 \(jdb\) """)189 g(jnp.array(2., jnp.float32))190 jax.effects_barrier()191 self.assertRegex(stdout.getvalue(), expected)192 @jtu.skip_on_devices(*disabled_backends)193 def test_can_use_multiple_breakpoints(self):194 stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])195 def f(x):196 y = x + 1.197 debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,198 backend="cli")199 return y200 @jax.jit201 def g(x):202 y = f(x) * 2.203 debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,204 backend="cli")205 return jnp.exp(y)206 expected = _format_multiline(r"""207 Entering jdb:208 (jdb) array(3., dtype=float32)209 (jdb) Entering jdb:210 (jdb) array(6., dtype=float32)211 (jdb) """)212 g(jnp.array(2., jnp.float32))213 jax.effects_barrier()214 self.assertEqual(stdout.getvalue(), expected)215 @jtu.skip_on_devices(*disabled_backends)216 def test_debugger_works_with_vmap(self):217 stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])218 # On TPU, the breakpoints can be reordered inside of vmap but can be fixed219 # by ordering sends.220 # TODO(sharadmv): change back to ordered = False when sends are ordered221 
ordered = jax.default_backend() == "tpu"222 def f(x):223 y = x + 1.224 debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered,225 backend="cli")226 return 2. * y227 @jax.jit228 @jax.vmap229 def g(x):230 y = f(x)231 return jnp.exp(y)232 expected = _format_multiline(r"""233 Entering jdb:234 (jdb) array(1., dtype=float32)235 (jdb) Entering jdb:236 (jdb) array(2., dtype=float32)237 (jdb) """)238 g(jnp.arange(2., dtype=jnp.float32))239 jax.effects_barrier()240 self.assertEqual(stdout.getvalue(), expected)241 @jtu.skip_on_devices(*disabled_backends)242 def test_debugger_works_with_pmap(self):243 if jax.local_device_count() < 2:244 raise unittest.SkipTest("Test requires >= 2 devices.")245 stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])246 def f(x):247 y = jnp.sin(x)248 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")249 return y250 @jax.pmap251 def g(x):252 y = f(x)253 return jnp.exp(y)254 expected = _format_multiline(r"""255 Entering jdb:256 \(jdb\) array\(.*, dtype=float32\)257 \(jdb\) Entering jdb:258 \(jdb\) array\(.*, dtype=float32\)259 \(jdb\) """)260 g(jnp.arange(2., dtype=jnp.float32))261 jax.effects_barrier()262 self.assertRegex(stdout.getvalue(), expected)263 @jtu.skip_on_devices(*disabled_backends)264 def test_debugger_works_with_pjit(self):265 if jax.default_backend() != "tpu":266 raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")267 stdin, stdout = make_fake_stdin_stdout(["p y", "c"])268 def f(x):269 y = x + 1270 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")271 return y272 def g(x):273 y = f(x)274 return jnp.exp(y)275 g = pjit.pjit(g, in_axis_resources=pjit.PartitionSpec("dev"),276 out_axis_resources=pjit.PartitionSpec("dev"))277 with maps.Mesh(np.array(jax.devices()), ["dev"]):278 arr = (1 + np.arange(8)).astype(np.int32)279 expected = _format_multiline(r"""280 Entering jdb:281 \(jdb\) {}282 \(jdb\) """.format(re.escape(repr(arr))))283 g(jnp.arange(8, dtype=jnp.int32))284 
jax.effects_barrier()285 print(stdout.getvalue())286 print(expected)287 self.assertRegex(stdout.getvalue(), expected)288 @jtu.skip_on_devices(*disabled_backends)289 def test_debugger_uses_local_before_global_scope(self):290 stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])291 foo = "outer"292 def f(x):293 foo = "inner"294 debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")295 del foo296 return x297 del foo298 expected = _format_multiline(r"""299 Entering jdb:300 \(jdb\) 'inner'301 \(jdb\) """)302 f(2.)303 jax.effects_barrier()304 print(stdout.getvalue())305 print(expected)306 self.assertRegex(stdout.getvalue(), expected)307if __name__ == '__main__':...
Learn automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving deeper into advanced test scenarios. The LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks, e.g. Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!