Code: Select all
# To see more debug info, please use `graph_module.print_readable()`
Model successfully exported to model.pte
/usr/local/lib/python3.10/dist-packages/executorch/exir/emit/_emitter.py:1512: UserWarning: Mutation on a buffer in the model is detected. ExecuTorch assumes buffers that are mutated in the graph have a meaningless initial state, only the shape and dtype will be serialized.
warnings.warn(
Ich verwende T4 in Google Colab, um diese Funktion auszuführen.
Code: Select all
def quantize_model(model: nn.Module, example_inputs: tuple) -> nn.Module:
print(f"Original model architecture:\n{model}")
# Initialize quantizer with symmetric quantization configuration
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
# Prepare model for quantization
pre_autograd_model = export_for_training(model, example_inputs).module()
prepared_model = prepare_pt2e(pre_autograd_model, quantizer)
# Calibrate and convert the model
prepared_model(*example_inputs)
quantized_model = convert_pt2e(prepared_model)
print(f"Quantized model architecture:\n{quantized_model}")
return quantized_model
def export_to_executorch(model: nn.Module, example_inputs: tuple, output_path: str):
    """Lower *model* to the XNNPACK backend and write it out as an ExecuTorch program.

    Args:
        model: Module to export (here, the output of ``quantize_model``).
        example_inputs: Tuple of positional inputs used by ``export`` to trace.
        output_path: Destination file for the serialized program bytes.
    """
    exported_program = export(model, example_inputs)

    # Lower to the Edge dialect with IR validation on, and delegate supported
    # subgraphs to XNNPACK.
    edge_manager = to_edge_transform_and_lower(
        exported_program,
        compile_config=EdgeCompileConfig(_check_ir_validity=True),
        partitioner=[XnnpackPartitioner()],
    )

    # No extra backend passes are requested.
    backend_config = ExecutorchBackendConfig(passes=[])
    program: exir.ExecutorchProgramManager = edge_manager.to_executorch(backend_config)

    with open(output_path, "wb") as out_file:
        out_file.write(program.buffer)
    print(f"Model successfully exported to {output_path}")
def main():
    """Build ResEmoteNet, quantize it, and save it as ``model.pte``."""
    # Input of shape (1, 1, 256, 256); presumably NCHW with one channel —
    # confirm against ResEmoteNet's expected input.
    example_inputs = (torch.randn(1, 1, 256, 256),)

    net = ResEmoteNet()  # defined elsewhere in the project
    export_to_executorch(
        model=quantize_model(net, example_inputs),
        example_inputs=example_inputs,
        output_path="model.pte",
    )


if __name__ == "__main__":
    main()
Mobile version