Skip to content

Commit 02f71a9

Browse files
committed
Make SMC Op wrappers respect node dtype not config
1 parent 4747070 commit 02f71a9

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

pymc/distributions/simulator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def make_node(self, x):
338338

339339
def perform(self, node, inputs, outputs):
340340
(x,) = inputs
341-
outputs[0][0] = np.atleast_1d(fn(x)).astype(pytensor.config.floatX)
341+
outputs[0][0] = np.atleast_1d(fn(x)).astype(node.outputs[0].dtype)
342342

343343
return SumStat()
344344

@@ -365,8 +365,6 @@ def make_node(self, epsilon, obs_data, sim_data):
365365

366366
def perform(self, node, inputs, outputs):
367367
eps, obs_data, sim_data = inputs
368-
outputs[0][0] = np.atleast_1d(fn(eps, obs_data, sim_data)).astype(
369-
pytensor.config.floatX
370-
)
368+
outputs[0][0] = np.atleast_1d(fn(eps, obs_data, sim_data)).astype(node.outputs[0].dtype)
371369

372370
return Distance()

0 commit comments

Comments
 (0)