Skip to content

Commit 1d9d51e

Browse files
committed
Fix ordering bug between untraced_sit_sot and nit_sot
Internally Scan places those outputs to the right of nit_sot. Helper function reordering to match user definition was not handling this correctly
1 parent 16cdb86 commit 1d9d51e

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

pytensor/scan/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def wrap_into_list(x):
661661
else:
662662
actual_n_steps = pt.as_tensor(n_steps, dtype="int64", ndim=0)
663663

664-
# Since we've added all sequences now we need to level them up based on
664+
# Since we've added all sequences now we need to level them off based on
665665
# n_steps or their different shapes
666666
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
667667

@@ -1212,8 +1212,8 @@ def remove_dimensions(outs, offsets=None):
12121212
rightOrder = (
12131213
mit_sot_rightOrder
12141214
+ sit_sot_rightOrder
1215-
+ untraced_sit_sot_rightOrder
12161215
+ nit_sot_rightOrder
1216+
+ untraced_sit_sot_rightOrder
12171217
)
12181218
scan_out_list = [None] * len(rightOrder)
12191219
for idx, pos in enumerate(rightOrder):

tests/scan/test_basic.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4470,3 +4470,49 @@ def onestep(seq, seq_tm4):
44704470
f_infershape = function([seq, init], out_seq_tm4[1].shape)
44714471
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
44724472
assert len(scan_nodes_infershape) == 0
4473+
4474+
4475+
@pytest.mark.parametrize("single_step", (True, False))
4476+
def test_scan_mapped_and_non_traced_output_ordering(single_step):
4477+
# Regression test for https://github.com/pymc-devs/pytensor/issues/1796
4478+
4479+
rng = random_generator_type("rng")
4480+
4481+
def x_then_rng(rng):
4482+
next_rng, x = pt.random.normal(rng=rng).owner.outputs
4483+
return x, next_rng
4484+
4485+
xs, final_rng = scan(
4486+
fn=x_then_rng,
4487+
outputs_info=[None, rng],
4488+
n_steps=1 if single_step else 5,
4489+
return_updates=False,
4490+
)
4491+
assert isinstance(xs.type, TensorType)
4492+
assert isinstance(final_rng.type, RandomGeneratorType)
4493+
4494+
def rng_then_x(rng):
4495+
x, next_rng = x_then_rng(rng)
4496+
return next_rng, x
4497+
4498+
final_rng, xs = scan(
4499+
fn=rng_then_x,
4500+
outputs_info=[rng, None],
4501+
n_steps=1 if single_step else 5,
4502+
return_updates=False,
4503+
)
4504+
assert isinstance(xs.type, TensorType)
4505+
assert isinstance(final_rng.type, RandomGeneratorType)
4506+
4507+
def rng_between_xs(rng):
4508+
x, next_rng = x_then_rng(rng)
4509+
return x, next_rng, x + 1, x + 2
4510+
4511+
xs1, final_rng, xs2, xs3 = scan(
4512+
fn=rng_between_xs,
4513+
outputs_info=[None, rng, None, None],
4514+
n_steps=1 if single_step else 5,
4515+
return_updates=False,
4516+
)
4517+
assert all(isinstance(xs.type, TensorType) for xs in (xs1, xs2, xs3))
4518+
assert isinstance(final_rng.type, RandomGeneratorType)

0 commit comments

Comments
 (0)