Skip to content

Commit

Permalink
[ONNX] Remove useless methods (#1938)
Browse files Browse the repository at this point in the history
Changes:
Removed outdated methods in ONNXGraph
  • Loading branch information
l-bat authored Jun 29, 2023
1 parent 1683a6a commit 68b3037
Showing 1 changed file with 1 addition and 25 deletions.
26 changes: 1 addition & 25 deletions nncf/onnx/graph/onnx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import onnx
Expand Down Expand Up @@ -254,30 +254,6 @@ def get_bias_tensor_port_id(self, node: onnx.NodeProto) -> int:
return weight_definitions.bias_port_id
raise RuntimeError(f"The node {node} does not have bias_port_id attribute")

def _get_weight_tensor_with_reshape(self, node: onnx.NodeProto) -> Tuple[str, np.ndarray]:
"""
Returns node's weight tensor name and its value in the case when reshape node is placed after the weight.
The returned weight tensor will be reshaped according to a shape attribute of the reshape node.
:param node: Reshape node, whose input is weight tensor.
:return: The weight tensor name and its value with applied the reshape operation.
"""
tensor_name = node.output[0]
shape = self.get_initializers_value(node.input[1])
tensor_value = self.get_initializers_value(node.input[0])
reshaped_tensor_value = tensor_value.reshape(shape)
return tensor_name, reshaped_tensor_value

def _get_tensor_from_zero_input(self, node: onnx.NodeProto) -> Tuple[str, np.ndarray]:
"""
Returns the weight tensor name and its value, which is located on the 0-index input port of the node.
:param node: Node, which takes on the 0-index input port id the weight tensor.
:return: The weight tensor name and its value.
"""
tensor_name = self.get_initializer(node.input[0]).name
return tensor_name, self.get_initializers_value(tensor_name)

def get_node_index(self, node_name: str) -> int:
"""
Returns the node index in the model.
Expand Down

0 comments on commit 68b3037

Please sign in to comment.