Best Python code snippet using slash
profile.py
Source:profile.py
import os
import time
import onnx
import torch
import argparse
from prototype.spring.utils.log_helper import default_logger as logger


class ModelProfile(object):
    def __init__(self, log_result=True):
        self.M = 1e6
        self.log_result = log_result

    def count_params(self, model):
        total_param = sum(p.numel() for p in model.parameters())
        conv_param = 0
        fc_param = 0
        others_param = 0
        for name, m in model.named_modules():
            # skip non-leaf modules
            if len(list(m.children())) > 0:
                continue
            # params of the current leaf module
            num = sum(p.numel() for p in m.parameters())
            if isinstance(m, torch.nn.Conv2d):
                conv_param += num
            elif isinstance(m, torch.nn.Linear):
                fc_param += num
            else:
                others_param += num
        total_param /= self.M
        conv_param /= self.M
        fc_param /= self.M
        others_param /= self.M
        if self.log_result:
            logger.info('Profiling information of model on Params.\n'
                        'Total param: {:.3f}M, conv: {:.3f}M, fc: {:.3f}M, others: {:.3f}M'.format(
                            total_param, conv_param, fc_param, others_param))
        return total_param, conv_param, fc_param, others_param

    @torch.no_grad()
    def count_flops(self, model, input_size=(1, 3, 224, 224)):
        """
        args:
            input_size: for example (1, 3, 224, 224)
        """
        flops_dict = {}

        def make_conv2d_hook(name):
            def conv2d_hook(m, input):
                n, _, h, w = input[0].shape
                # conv FLOPs: the output spatial size is approximated by
                # dividing the input H x W by the strides
                flops = n * h * w * m.in_channels * m.out_channels \
                    * m.kernel_size[0] * m.kernel_size[1] \
                    / m.stride[0] / m.stride[1] / m.groups
                flops_dict[name] = int(flops)
            return conv2d_hook

        hooks = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.Conv2d):
                h = m.register_forward_pre_hook(make_conv2d_hook(name))
                hooks.append(h)
        dummy_input = torch.zeros(*input_size)
        model.eval()
        _ = model(dummy_input)
        model.train()
        total_flops = 0
        for k, v in flops_dict.items():
            # logger.info('module {}: {}'.format(k, v))
            total_flops += v
        if self.log_result:
            logger.info('Profiling information of model on FLOPs.\n'
                        'Total FLOPs: {:.3f}M'.format(total_flops / self.M))
        for h in hooks:
            h.remove()
        return total_flops / self.M

    def test_latency(self,
                     model,
                     input_size=(1, 3, 224, 224),
                     save_path='./',
                     hardware='cpu',
                     backend='nart',
                     batch_size=64,
                     data_type='fp32',
                     graph_name='',
                     force_test=False):
        """ Convert model into ONNX, then test hardware-related latency.
        args:
            model: pytorch model "nn.Module"
            input_size: tuple of 4 int
            save_path: path to save ONNX model
            hardware: hardware type, e.g. ['cpu', 'T4', 'P4', '3559A', '3519A']
            backend: backend type, e.g.
                ['nart', 'ppl2', 'cuda10.0-trt5.0', 'cuda10.0-trt7.0', 'cuda11.0-trt7.1',
                 'cuda11.0-nart', 'hisvp-nnie11', 'hisvp-nnie12']
            batch_size: int
            data_type: ['fp32', 'fp16', 'int8']
            graph_name: tag for this model
            force_test: force to test latency no matter whether this model has been tested
        """
        # not implemented; the original ONNX-export / latency-server flow is kept below for reference
        return 0

# logger.info('Converting model into ONNX type...')
# if get_rank() == 0:
#     save_prefix = os.path.join(save_path, 'ONNX/model')
#     if not os.path.exists(save_prefix):
#         os.makedirs(save_prefix)
#     with pytorch.convert_mode():
#         pytorch.export_onnx(
#             model, [input_size],
#             filename=save_prefix,
#             input_names=['data'],
#             output_names=['output'],
#             verbose=False,
#             cloze=False
#         )
#     # link.barrier()
# logger.info('Merging BN of the ONNX model...')
# onnx_model = onnx.load(save_prefix + '.onnx')
# graph = OnnxDecoder().decode(onnx_model)
# graph.update_tensor_shape()
# graph.update_topology()
# ConvFuser().run(graph)
# GemmFuser().run(graph)
# DeadCodeElimination().run(graph)
# graph.update_topology()
#
# onnx_model = Model.make_model(graph)
# onnx_model = onnx_model.dump_to_onnx()
# onnx_file_path = save_prefix + '_merged.onnx'
# onnx.save(onnx_model, onnx_file_path)
#
# logger.info('Test latency using ONNX model...')
# latency_client = Latency()
# if get_rank() == 0:
#     test_status_success = False
#     while not test_status_success:
#         ret = latency_client.call(
#             hardware_name=hardware,
#             backend_name=backend,
#             data_type=data_type,
#             batch_size=batch_size,
#             onnx_file=onnx_file_path,
#             graph_name=graph_name,
#             force_test=force_test,
#         )
#         if ret is None:
#             test_status_success = False
#         else:
#             test_status_success = ret['ret']['status'] == 'success'
#         logger.info('Whether test succeeded: {}'.format(test_status_success))
#         logger.info(ret)
#         if not test_status_success:
#             time.sleep(10)
#
# if self.log_result:
#     logger.info('Profiling information of model on Latency.\n'
#                 'On the platform {}-{}-{}-{}, the latency is {} ms.'.format(
#                     hardware, backend, data_type, batch_size, ret['cost_time']))
# return ret['cost_time']


MODEL_PROFILER = ModelProfile()


def get_model_profile(model, input_size=224, input_channel: int = 3, batch_size: int = 1, hardware: str = 'T4',
                      backend: str = 'cuda11.0-trt7.1', data_type: str = 'int8', force_test: bool = False):
    '''
    args:
        model: model name, e.g. "resnet18c_x0_125", or pytorch model "nn.Module"
        input_size: int, or tuple of two integers giving height and width respectively,
            e.g. input_size=224, input_size=(224, 224)
        input_channel: int
        batch_size: int
        hardware: hardware type, e.g. ['cpu', 'T4', 'P4', '3559A', '3519A']
        backend: backend type, e.g.
            ['nart', 'ppl2', 'cuda10.0-trt5.0', 'cuda10.0-trt7.0', 'cuda11.0-trt7.1',
             'cuda11.0-nart', 'hisvp-nnie11', 'hisvp-nnie12']
        data_type: ['fp32', 'fp16', 'int8']
        force_test: force to test latency no matter whether this model has been tested
    '''
    from prototype.spring.models import SPRING_MODELS_REGISTRY
    MODEL_PROFILER = ModelProfile(log_result=False)
    if isinstance(model, str):
        try:
            graph_name = 'spring.models.' + model
            model = SPRING_MODELS_REGISTRY.get(model)(task='classification')
        except NotImplementedError:
            print('model "{}" not found in SPRING_MODELS_REGISTRY'.format(model))
    else:
        assert isinstance(model, torch.nn.Module), \
            'the argument must be a model name or a model instance, but got {}'.format(type(model))
        graph_name = ''
    if isinstance(input_size, (tuple, list)):
        input_size = (1, input_channel, input_size[0], input_size[1])
    elif isinstance(input_size, int):
        input_size = (1, input_channel, input_size, input_size)
    else:
        raise ValueError('Expected input_size to be tuple/list or int, but got {} with type {}'
                         .format(input_size, type(input_size)))
    total_param, conv_param, fc_param, others_param = MODEL_PROFILER.count_params(model)
    flops = MODEL_PROFILER.count_flops(model, input_size=input_size)
    # latency = MODEL_PROFILER.test_latency(model, input_size=input_size, hardware=hardware,
    #                                       backend=backend, batch_size=batch_size, data_type=data_type,
    #                                       graph_name=graph_name, force_test=force_test)
    logger.info('Profiling information of model on Params.\n'
                'Total param: {:.3f}M, conv: {:.3f}M, fc: {:.3f}M, others: {:.3f}M'.format(
                    total_param, conv_param, fc_param, others_param))
    logger.info('Profiling information of model on FLOPs.\n'
                'Total FLOPs: {:.3f}M'.format(flops))
    # logger.info('Profiling information of model on Latency.\n'
    #             'On the platform {}-{}-{}-b{}, the latency is {} ms.'.format(
    #                 hardware, backend, data_type, batch_size, latency))
    return {'total_param': total_param, 'conv_param': conv_param, 'other_param': others_param,
            'flops': flops}  # , 'latency': latency}


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Model profiling')
    parser.add_argument('--model', type=str, help='model name', required=True)
    parser.add_argument('--input-size', default=224, type=int, nargs='+', help='model input size')
    parser.add_argument('--input-channel', default=3, type=int, help='model input channel')
    parser.add_argument('--batch-size', default=64, type=int, help='input batch size')
    parser.add_argument('--hardware', default='cpu', type=str, help='hardware type, e.g. [cpu, T4, P4, 3559A, 3519A]')
    parser.add_argument('--backend', default='nart', type=str, help='backend type, e.g. [nart, ppl2, cuda10.0-trt5.0, '
                        'cuda10.0-trt7.0, cuda11.0-trt7.1, cuda11.0-nart, hisvp-nnie11, hisvp-nnie12]')
    parser.add_argument('--data-type', default='fp32', type=str, help='[fp32, fp16, int8]')
    parser.add_argument('--force-test', action='store_true', default=False, help='force test without querying database')
    args = parser.parse_args()
    get_model_profile(args.model, input_size=args.input_size, input_channel=args.input_channel,
                      batch_size=args.batch_size, hardware=args.hardware, backend=args.backend,...
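The two counters only need a forward pass, so they can be exercised without the internal latency service or the SPRING registry. A minimal usage sketch, assuming the ModelProfile class above is importable and torchvision is installed (resnet18 is just an illustrative stand-in for a registry model):

import torch
import torchvision

profiler = ModelProfile(log_result=False)
model = torchvision.models.resnet18()

# all values are reported in millions (divided by self.M == 1e6)
total, conv, fc, others = profiler.count_params(model)
flops = profiler.count_flops(model, input_size=(1, 3, 224, 224))
print('params: {:.3f}M (conv {:.3f}M, fc {:.3f}M, others {:.3f}M), FLOPs: {:.3f}M'.format(
    total, conv, fc, others, flops))

Note that count_flops only hooks torch.nn.Conv2d modules, so linear layers and other ops do not contribute to the reported total.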
latency.py
Source:latency.py
import os
import re
import ssl
import certifi
import urllib3
import mimetypes
import json
import click
from six.moves.urllib.parse import urlencode


class HTTPClient:
    def __init__(self, server_host, pool_size=4, max_parallel_size=4, verify_ssl=True, ssl_ca_cert=None):
        self.server = server_host + "/api/v1/latency"
        self.user_name = get_user_name()
        # cert_reqs
        cert_reqs = ssl.CERT_REQUIRED if verify_ssl is True else ssl.CERT_NONE
        # ca_certs: if no certificate file is set, use Mozilla's root certificates
        ca_certs = ssl_ca_cert if ssl_ca_cert is not None else certifi.where()
        # https pool manager
        self.pool_manager = urllib3.PoolManager(
            num_pools=pool_size,
            maxsize=max_parallel_size,
            cert_reqs=cert_reqs,
            ca_certs=ca_certs,
            cert_file=None,
            key_file=None
        )

    def call(self, hardware_name, backend_name, data_type, batch_size, onnx_file, graph_name="", force_test=False):
        if not os.path.exists(onnx_file):
            raise Exception("File {} does not exist!".format(onnx_file))
        try:
            # query tuple params
            query_params = [
                ('hardware_name', sanitize_str(hardware_name)),
                ('backend_name', sanitize_str(backend_name)),
                ('data_type', sanitize_str(data_type)),
                ('batch_size', int(batch_size)),
                ('user_name', sanitize_str(self.user_name)),
                ('graph_name', sanitize_str(graph_name)),
                ('force', 'true' if force_test is True else 'false')
            ]
            url = self.server + '?' + urlencode(query_params)
            # post params
            post_params = []
            with open(onnx_file, 'rb') as f:
                filename = os.path.basename(f.name)
                filedata = f.read()
            mimetype = (mimetypes.guess_type(filename)[0] or 'application/octet-stream')
            post_params.append(tuple(['onnx_file', tuple([filename, filedata, mimetype])]))
            # headers
            headers = {
                'Accept': 'application/json',
                'Content-Type': 'multipart/form-data',
                'User-Agent': 'GPDB-client'
            }
            # headers['Content-Type'] must be deleted, or the correct Content-Type
            # generated by urllib3 would be overwritten.
            del headers['Content-Type']
            response = self.pool_manager.request(
                'POST', url,
                fields=post_params,
                encode_multipart=True,
                preload_content=True,
                timeout=None,
                headers=headers
            )
            ret_val = json.loads(response.data)
            return ret_val
        except Exception as e:
            print("HTTPClient Call Error: {}!".format(e))


def get_user_name():
    if 'USER' in os.environ:
        user_name = os.environ['USER']
    elif 'USERNAME' in os.environ:
        user_name = os.environ['USERNAME']
    else:
        user_name = "unknown"
    return user_name


def sanitize_str(ss):
    return re.sub(r'[^0-9a-zA-Z_\.\-]', '', str(ss))


class Latency(object):
    """An information object to pass data between CLI functions."""

    def __init__(self):  # Note: This object must have an empty constructor.
        """Create a new instance."""
        self.server = 'http://10.10.40.93:32770/gpdb/'

    def call(self, hardware_name, backend_name, data_type, batch_size, onnx_file,
             graph_name="", force_test=False, print_info=False):
        http_client = HTTPClient(self.server)
        ret_val = http_client.call(hardware_name, backend_name, data_type, batch_size, onnx_file,
                                   graph_name=graph_name, force_test=force_test)
        if print_info:
            if ret_val is not None:
                print(json.dumps(ret_val))
            else:
                print("Server Not Running!")
        return ret_val


# pass_info is a decorator for functions that pass 'Info' objects.
#: pylint: disable=invalid-name
pass_cfg = click.make_pass_decorator(Latency, ensure=True)
default_cfg = Latency()


@click.option("--server", "-s", type=str, help="spring.models.latency server url",
              default=default_cfg.server, show_default=True)
@click.option('--hardware', 'hardware_name', type=str,
              help='target hardware', required=True)
@click.option('--backend', 'backend_name', type=str,
              help='target backend', required=True)
@click.option('--data_type', 'data_type', type=str,
              help='data type, eg: int8', required=True)
@click.option('--batch_size', 'batch_size', type=int,
              help='batch size, eg: 8', required=True)
@click.option('--model_file', 'model_file',
              help='source model file', required=True)
@click.option('--graph_name', 'graph_name', type=str,
              help='graph name', required=False, default="")
@click.option('--force_test', 'force_test', is_flag=True,
              help='force test without querying database')
@click.command()
@pass_cfg
def ctl(cfg: Latency, server, hardware_name, backend_name, data_type,
        batch_size, model_file, graph_name, force_test):
    """Latency command line tools"""
    cfg.server = server
    cfg.call(hardware_name, backend_name, data_type, batch_size, model_file,
             graph_name=graph_name, force_test=force_test, print_info=True)


def main():...
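Besides the click CLI, the client can be driven programmatically. A hedged sketch: the default server baked into Latency.__init__ is an internal address, and 'model_merged.onnx' is a placeholder path, so this only runs inside the original deployment:

# placeholder ONNX path; the default server is internal to the original deployment
client = Latency()
ret = client.call(hardware_name='T4', backend_name='cuda11.0-trt7.1',
                  data_type='int8', batch_size=8,
                  onnx_file='model_merged.onnx', print_info=True)
# profile.py's commented-out flow checks ret['ret']['status'] and reads ret['cost_time']
if ret is not None and ret.get('ret', {}).get('status') == 'success':
    print('latency: {} ms'.format(ret['cost_time']))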
account_invoice.py
Source:account_invoice.py
# -*- coding: utf-8 -*-
# Copyright 2015 - Camptocamp SA - Author Vincent Renaville
# Copyright 2016 - Tecnativa - Angel Moya <odoo@tecnativa.com>
# Copyright 2019 - Tecnativa - Pedro M. Baeza
# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl).
from odoo import models, api, _
from odoo.exceptions import UserError
from odoo.tools import config


class AccountInvoice(models.Model):
    _inherit = "account.invoice"

    @api.multi
    def _test_invoice_line_tax(self):
        errors = []
        error_template = _("Invoice has a line with product %s with no taxes")
        for invoice_line in self.mapped('invoice_line_ids'):
            if not invoice_line.invoice_line_tax_ids:
                error_string = error_template % (invoice_line.name)
                errors.append(error_string)
        if errors:
            raise UserError(
                _('%s\n%s') % (_('No Taxes Defined!'),
                               '\n'.join(x for x in errors))
            )

    @api.multi
    def action_invoice_open(self):
        # Always test if it is required by context
        force_test = self.env.context.get('test_tax_required')
        skip_test = any((
            # It usually fails when installing other addons with demo data
            self.sudo().env["ir.module.module"].search([
                ("state", "in", ["to install", "to upgrade"]),
                ("demo", "=", True),
            ]),
            # Avoid breaking unaware addons' tests by default
            config["test_enable"],
        ))
        if force_test or not skip_test:
            self._test_invoice_line_tax()...
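The test_tax_required context key is what forces the tax check even when the skip conditions hold. A short sketch under the assumption of an existing invoice record (invoice_id is hypothetical):

# hypothetical record; forces _test_invoice_line_tax() regardless of the
# demo-data / test_enable skip conditions above
invoice = env['account.invoice'].browse(invoice_id)
invoice.with_context(test_tax_required=True).action_invoice_open()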