Skip to content

Commit 6fc412f

Browse files
committed
atol
1 parent ad3f491 commit 6fc412f

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

_unittests/ut_light_api/test_backend_export.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,19 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
244244
raise NotImplementedError("Unable to run the model node by node.")
245245

246246

247-
backend_test = onnx.backend.test.BackendTest(ExportBackend, __name__)
247+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
248+
backend_test = onnx.backend.test.BackendTest(
249+
ExportBackend,
250+
__name__,
251+
test_kwargs={
252+
"test_dft": {"atol": dft_atol},
253+
"test_dft_axis": {"atol": dft_atol},
254+
"test_dft_axis_opset19": {"atol": dft_atol},
255+
"test_dft_inverse": {"atol": dft_atol},
256+
"test_dft_inverse_opset19": {"atol": dft_atol},
257+
"test_dft_opset19": {"atol": dft_atol},
258+
},
259+
)
248260

249261
# The following tests are too slow with the reference implementation (Conv).
250262
backend_test.exclude(

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import platform
3+
import sys
34
import unittest
45
from typing import Any
56
import numpy
@@ -78,10 +79,21 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
7879
raise NotImplementedError("Unable to run the model node by node.")
7980

8081

82+
dft_atol = 1e-3 if sys.platform != "linux" else 1e-5
8183
backend_test = onnx.backend.test.BackendTest(
82-
ExtendedReferenceEvaluatorBackend, __name__
84+
ExtendedReferenceEvaluatorBackend,
85+
__name__,
86+
test_kwargs={
87+
"test_dft": {"atol": dft_atol},
88+
"test_dft_axis": {"atol": dft_atol},
89+
"test_dft_axis_opset19": {"atol": dft_atol},
90+
"test_dft_inverse": {"atol": dft_atol},
91+
"test_dft_inverse_opset19": {"atol": dft_atol},
92+
"test_dft_opset19": {"atol": dft_atol},
93+
},
8394
)
8495

96+
8597
if os.getenv("APPVEYOR"):
8698
backend_test.exclude("(test_vgg19|test_zfnet)")
8799
if platform.architecture()[0] == "32bit":

0 commit comments

Comments
 (0)