Skip to content
Open

style #107

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion _doc/examples/plot_onnx_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from onnx_array_api.reference import compare_onnx_execution
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot


data = load_iris()
X_train, X_test = train_test_split(data.data)
model = GaussianMixture()
Expand Down
1 change: 0 additions & 1 deletion _doc/examples/plot_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from onnx_array_api.ext_test_case import measure_time
from onnx_array_api.ort.ort_optimizers import ort_optimized_model


filename = example_path("data/small.onnx")
optimized = filename + ".optimized.onnx"

Expand Down
1 change: 0 additions & 1 deletion _doc/examples/plot_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile
from onnx_array_api.plotting.stat_plot import plot_ort_profile


suffix = ""
filename = example_path(f"data/small{suffix}.onnx")
optimized = filename + ".optimized.onnx"
Expand Down
36 changes: 12 additions & 24 deletions _unittests/ut_graph_api/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,38 @@ def call_optimizer(self, onx):
return gr.to_onnx()

def test_remove_unused_nodes(self):
model = onnx.parser.parse_model(
"""
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, x)
}"""
)
}""")
onx = self.call_optimizer(model)
self.assertEqual(len(onx.graph.node), 1)
self.assertEqual(onx.graph.node[0].op_type, "Mul")

def test_initializers(self):
model = onnx.parser.parse_model(
"""
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z)
<float two = {2.0}> {
four = Add(two, two)
z = Mul(x, x)
}"""
)
}""")
self.assertEqual(len(model.graph.initializer), 1)
onx = self.call_optimizer(model)
self.assertEqual(len(onx.graph.node), 1)
self.assertEqual(onx.graph.node[0].op_type, "Mul")
self.assertEqual(len(onx.graph.initializer), 0)

def test_keep_unused_outputs(self):
model = onnx.parser.parse_model(
"""
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[M] z) {
w1, w2, w3 = Split (x)
z = Mul(w3, w3)
}"""
)
}""")
onx = self.call_optimizer(model)
self.assertEqual(len(onx.graph.node), 2)
self.assertEqual(onx.graph.node[0].op_type, "Split")
Expand Down Expand Up @@ -381,30 +375,26 @@ def test_make_nodes_noprefix(self):
self.assertEqualArray(expected, got[0])

def test_node_pattern(self):
model = onnx.parser.parse_model(
"""
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, four)
}"""
)
}""")
gr = GraphBuilder(model)
p = gr.np(index=0)
r = repr(p)
self.assertEqual("NodePattern(index=0, op_type=None, name=None)", r)

def test_update_node_attribute(self):
model = onnx.parser.parse_model(
"""
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, four)
}"""
)
}""")
gr = GraphBuilder(model)
self.assertEqual(len(gr.nodes), 3)
m = gr.update_attribute(gr.np(op_type="Constant"), value_float=float(1))
Expand All @@ -416,15 +406,13 @@ def test_update_node_attribute(self):
self.assertIn("f: 1", str(node))

def test_delete_node_attribute(self):
model = onnx.parser.parse_model(
"""
model = onnx.parser.parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, four)
}"""
)
}""")
gr = GraphBuilder(model)
self.assertEqual(len(gr.nodes), 3)
m = gr.update_attribute(
Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_npx/test_sklearn_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor


DEFAULT_OPSET = onnx_opset_version()


Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_ort/test_ort_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.ort.ort_optimizers import ort_optimized_model


DEFAULT_OPSET = onnx_opset_version()


Expand Down
1 change: 0 additions & 1 deletion _unittests/ut_ort/test_sklearn_array_api_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows
from onnx_array_api.ort.ort_tensors import EagerOrtTensor, OrtTensor


DEFAULT_OPSET = onnx_opset_version()


Expand Down
6 changes: 2 additions & 4 deletions _unittests/ut_plotting/test_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
class TestGraphviz(ExtTestCase):
@classmethod
def _get_graph(cls):
return onnx.parser.parse_model(
"""
return onnx.parser.parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, x)
}"""
)
}""")

@skipif_ci_windows("graphviz not installed")
@skipif_ci_apple("graphviz not installed")
Expand Down
48 changes: 16 additions & 32 deletions _unittests/ut_plotting/test_text_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ def test_onnx_text_plot_tree_cls_2(self):
model_def = load(f)
res = onnx_text_plot_tree(model_def.graph.node[0])
self.assertIn("n_classes=3", res)
expected = textwrap.dedent(
"""
expected = textwrap.dedent("""
n_classes=3
n_trees=1
----
Expand All @@ -92,8 +91,7 @@ def test_onnx_text_plot_tree_cls_2(self):
-f 0:0 1:0 2:1
+f 0:0 1:1 2:0
+f 0:1 1:0 2:0
"""
).strip(" \n\r")
""").strip(" \n\r")
res = res.replace("np.float32(", "").replace(")", "")
self.assertEqual(expected, res.strip(" \n\r"))

Expand All @@ -104,39 +102,33 @@ def test_onnx_simple_text_plot_kmeans(self):
model.fit(x)
onx = to_onnx(model, x.astype(numpy.float32), target_opset=15)
text = onnx_simple_text_plot(onx)
expected1 = textwrap.dedent(
"""
expected1 = textwrap.dedent("""
ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0
Add(Re_reduced0, Ge_Y0) -> Ad_C01
Add(Ad_Addcst, Ad_C01) -> Ad_C0
Sqrt(Ad_C0) -> scores
ArgMin(Ad_C0, axis=1, keepdims=0) -> label
"""
).strip(" \n")
expected2 = textwrap.dedent(
"""
""").strip(" \n")
expected2 = textwrap.dedent("""
ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0
Add(Re_reduced0, Ge_Y0) -> Ad_C01
Add(Ad_Addcst, Ad_C01) -> Ad_C0
Sqrt(Ad_C0) -> scores
ArgMin(Ad_C0, axis=1, keepdims=0) -> label
"""
).strip(" \n")
expected3 = textwrap.dedent(
"""
""").strip(" \n")
expected3 = textwrap.dedent("""
ReduceSumSquare(X, axes=[1], keepdims=1) -> Re_reduced0
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
Gemm(X, Ge_Gemmcst, Mu_C0, alpha=-2.00, transB=1) -> Ge_Y0
Add(Re_reduced0, Ge_Y0) -> Ad_C01
Add(Ad_Addcst, Ad_C01) -> Ad_C0
ArgMin(Ad_C0, axis=1, keepdims=0) -> label
Sqrt(Ad_C0) -> scores
"""
).strip(" \n")
""").strip(" \n")
if expected1 not in text and expected2 not in text and expected3 not in text:
raise AssertionError(f"Unexpected value:\n{text}")

Expand Down Expand Up @@ -165,17 +157,15 @@ def test_onnx_simple_text_plot_toy(self):
{"X": x.astype(numpy.float32)}, outputs={"Y": x}, target_opset=15
)
text = onnx_simple_text_plot(onx, verbose=False)
expected = textwrap.dedent(
"""
expected = textwrap.dedent("""
Add(X, Ad_Addcst) -> Ad_C0
Abs(Ad_C0) -> Ab_Y0
Identity(Ad_Addcst) -> Su_Subcst
Sub(X, Su_Subcst) -> Su_C0
Abs(Su_C0) -> Ab_Y02
Div(Ab_Y0, Ab_Y02) -> Di_C0
Abs(Di_C0) -> Y
"""
).strip(" \n")
""").strip(" \n")
self.assertIn(expected, text)
text2, out, err = self.capture(lambda: onnx_simple_text_plot(onx, verbose=True))
self.assertEqual(text, text2)
Expand All @@ -188,11 +178,9 @@ def test_onnx_simple_text_plot_leaky(self):
{"X": FloatTensorType()}, outputs={"Y": FloatTensorType()}, target_opset=15
)
text = onnx_simple_text_plot(onx)
expected = textwrap.dedent(
"""
expected = textwrap.dedent("""
LeakyRelu(X, alpha=0.50) -> Y
"""
).strip(" \n")
""").strip(" \n")
self.assertIn(expected, text)

def test_onnx_text_plot_io(self):
Expand All @@ -201,11 +189,9 @@ def test_onnx_text_plot_io(self):
{"X": FloatTensorType()}, outputs={"Y": FloatTensorType()}, target_opset=15
)
text = onnx_text_plot_io(onx)
expected = textwrap.dedent(
"""
expected = textwrap.dedent("""
input:
"""
).strip(" \n")
""").strip(" \n")
self.assertIn(expected, text)

def test_onnx_simple_text_plot_if(self):
Expand Down Expand Up @@ -244,11 +230,9 @@ def test_onnx_simple_text_plot_if(self):
{"x1": x1, "x2": x2}, target_opset=opv, outputs=[("y", FloatTensorType())]
)
text = onnx_simple_text_plot(model_def)
expected = textwrap.dedent(
"""
expected = textwrap.dedent("""
input:
"""
).strip(" \n")
""").strip(" \n")
self.assertIn(expected, text)
self.assertIn("If(Gr_C0, else_branch=G1, then_branch=G2)", text)

Expand Down
30 changes: 9 additions & 21 deletions _unittests/ut_reference/test_evaluator_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,55 +431,43 @@ def test_distance_sequence_str(self):
005=|RESULTfloat322:2x2CEIOLinearRegressioY1|RESULTfloat322:2x2CEIOLinearRegressioY1
006~|RESULTfloat322:2x2CEIOAbsY|RESULTfloat322:2x3CEIPAbsZ
007~|OUTPUTfloat322:2x2CEIOY|OUTPUTfloat322:2x2CEIPY
""".replace(
" ", ""
).strip(
"\n "
)
""".replace(" ", "").strip("\n ")
self.maxDiff = None
self.assertEqual(expected, text.replace(" ", "").strip("\n"))

def test_compare_execution(self):
m1 = parse_model(
"""
m1 = parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, x)
}"""
)
m2 = parse_model(
"""
}""")
m2 = parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
z = Mul(x, x)
}"""
)
}""")
res1, res2, align, dc = compare_onnx_execution(m1, m2)
text = dc.to_str(res1, res2, align)
self.assertIn("CAAA Constant", text)
self.assertEqual(len(align), 5)

def test_compare_execution_discrepancies(self):
m1 = parse_model(
"""
m1 = parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
four = Add(two, two)
z = Mul(x, x)
}"""
)
m2 = parse_model(
"""
}""")
m2 = parse_model("""
<ir_version: 8, opset_import: [ "": 18]>
agraph (float[N] x) => (float[N] z) {
two = Constant <value_float=2.0> ()
z = Mul(x, x)
}"""
)
}""")
res1, res2, align, dc = compare_onnx_execution(m1, m2, keep_tensor=True)
text = dc.to_str(res1, res2, align)
print(text)
Expand Down
Loading
Loading