How to use the prune method in molecule

Best Python code snippets using molecule_python

ranked_structures_pruner.py

Source: ranked_structures_pruner.py (GitHub)


...
        # The "leader" is the first weights-tensor in the list
        return self.params_names[0]

    def is_supported(self, param_name):
        return param_name in self.params_names

    def fraction_to_prune(self, param_name):
        return self.desired_sparsity

    def set_param_mask(self, param, param_name, zeros_mask_dict, meta):
        if not self.is_supported(param_name):
            return
        fraction_to_prune = self.fraction_to_prune(param_name)
        try:
            model = meta['model']
        except TypeError:
            model = None
        return self.prune_to_target_sparsity(param, param_name, zeros_mask_dict, fraction_to_prune, model)

    def prune_to_target_sparsity(self, param, param_name, zeros_mask_dict, target_sparsity, model):
        if not self.is_supported(param_name):
            return
        binary_map = None
        if self.group_dependency == "Leader":
            if target_sparsity != self.last_target_sparsity:
                # Each time we change the target sparsity we need to compute and cache the leader's binary-map.
                # We don't have control over the order that this function is invoked, so the only indication that
                # we need to compute a new leader binary-map is the change of the target_sparsity.
...
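The excerpt above only shows the pruner's interface: fraction_to_prune reports the desired sparsity for a parameter, and set_param_mask / prune_to_target_sparsity turn that fraction into a binary mask stored in zeros_mask_dict. The sketch below illustrates the underlying idea with plain magnitude pruning; magnitude_prune_to_sparsity and the toy tensor are hypothetical and not part of ranked_structures_pruner.py, which ranks whole structures (filters, channels) rather than individual weights.

import torch

def magnitude_prune_to_sparsity(param, target_sparsity):
    # Hypothetical helper: zero out the smallest-magnitude weights so that
    # roughly `target_sparsity` of the entries become zero.
    num_elements = param.numel()
    num_to_prune = int(target_sparsity * num_elements)
    if num_to_prune == 0:
        return torch.ones_like(param)
    # Threshold = magnitude of the k-th smallest weight.
    threshold = torch.kthvalue(param.abs().flatten(), num_to_prune).values
    binary_map = (param.abs() > threshold).float()
    return binary_map  # 1 keeps a weight, 0 prunes it

# Usage sketch: mask a weight tensor to roughly 50% sparsity.
w = torch.randn(64, 128)
mask = magnitude_prune_to_sparsity(w, 0.5)
w_pruned = w * mask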


prune_main.py

Source: prune_main.py (GitHub)


import numpy as np
import argparse
import torch
from .prune_base import prune_parse_arguments as prune_base_parse_arguments
from .admm import ADMM, prune_parse_arguments as admm_prune_parse_arguments
from .retrain import Retrain, prune_parse_arguments as retrain_parse_arguments
from .admm import admm_adjust_learning_rate
from .multi_level_admm import MultiLevelADMM
from .multi_level_retrain import MultiLevelRetrain
from .utils_pr import prune_parse_arguments as utils_prune_parse_arguments

prune_algo = None
retrain = None

def main_prune_parse_arguments(parser):
    parser.add_argument('--sp-store-weights', type=str,
                        help="store the final weights, "
                             "maybe used by the next stage pruning")
    parser.add_argument("--sp-lars", action="store_true",
                        help="enable LARS learning rate scheduler")
    parser.add_argument('--sp-lars-trust-coef', type=float, default=0.001,
                        help="LARS trust coefficient")

def prune_parse_arguments(parser):
    main_prune_parse_arguments(parser)
    prune_base_parse_arguments(parser)
    admm_prune_parse_arguments(parser)
    utils_prune_parse_arguments(parser)
    retrain_parse_arguments(parser)

def prune_init(args, model, logger=None, fixed_model=None, pre_defined_mask=None):
    global prune_algo, retrain
    if args.sp_admm_multi:
        prune_algo = MultiLevelADMM(args, model, logger)
        return
    if args.sp_retrain:
        if args.sp_prune_before_retrain:
            # For prune before retrain, we need to also set sp-admm-sparsity-type in the command line.
            # We need to set sp_admm_update_epoch for ADMM, so set it to 1.
            args.sp_admm_update_epoch = 1
            prune_algo = ADMM(args, model, logger, False)
            prune_algo.prune_harden()
            prune_algo = None
        retrain = Retrain(args, model, logger, pre_defined_mask)
        retrain.fix_layer_weight_save()
        return
    if args.sp_retrain_multi:
        prune_algo = None
        retrain = MultiLevelRetrain(args, model, logger)
    if args.sp_admm:
        prune_algo = ADMM(args, model, logger)
        return
    if args.sp_subset_progressive:
        prune_algo = ADMM(args, model, logger)
        retrain = Retrain(args, model, logger)
        return

def prune_update(epoch=0, batch_idx=0):
    if prune_algo != None:
        return prune_algo.prune_update(epoch, batch_idx)
    elif retrain != None:
        return retrain.update_mask(epoch)

def prune_update_grad(opt):
    if retrain != None:
        return retrain.update_grad(opt)

def prune_fix_layer_restore():
    if prune_algo != None:
        return
    if retrain != None:
        return retrain.fix_layer_weight_restore()

def prune_generate_small_resnet_model(mode, ratio1, ratio2):
    pass

def prune_update_loss(loss):
    if prune_algo == None:
        return loss
    return prune_algo.prune_update_loss(loss)

def prune_update_combined_loss(loss):
    if prune_algo == None:
        return loss, loss, loss
    return prune_algo.prune_update_combined_loss(loss)

def prune_harden():
    if prune_algo == None:
        return None
    return prune_algo.prune_harden()

def prune_apply_masks():
    if prune_algo:
        prune_algo.apply_masks()
    if retrain:
        retrain.apply_masks()
    else:
        return
        assert(False)

def prune_apply_masks_on_grads():
    if prune_algo:
        prune_algo.apply_masks_on_grads()
    if retrain:
        retrain.apply_masks_on_grads()
    else:
        return
        assert(False)

def prune_retrain_show_masks(debug=False):
    if retrain == None:
        print("Retrain is None!")
        return
    retrain.show_masks(debug)

def prune_store_weights():
    model = None
    args = None
    logger = None
    if prune_algo:
        model = prune_algo.model
        args = prune_algo.args
        logger = prune_algo.logger
    elif retrain:
        model = retrain.model
        args = retrain.args
        logger = retrain.logger
    else:
        return
    filename = args.sp_store_weights
    if filename is None:
        return
    variables = {}
    if logger:
        p = logger.info
    else:
        p = print
    with torch.no_grad():
        p("Storing weights to {}".format(filename))
        torch.save(model.state_dict(), filename)

def prune_store_prune_params():
    if prune_algo == None:
        return
    return prune_algo.prune_store_params()

def prune_print_sparsity(model=None, logger=None, show_sparse_only=False, compressed_view=False):
    if model is None:
        if prune_algo:
            model = prune_algo.model
        elif retrain:
            model = retrain.model
        else:
            return
    if logger:
        p = logger.info
    elif prune_algo:
        p = prune_algo.logger.info
    elif retrain:
        p = retrain.logger.info
    else:
        p = print

    if show_sparse_only:
        print("The sparsity of all params (>0.01): num_nonzeros, total_num, sparsity")
        total_nz = 0
        total = 0
        for (name, W) in model.named_parameters():
            #print(name, W.shape)
            non_zeros = W.detach().cpu().numpy().astype(np.float32) != 0
            num_nonzeros = np.count_nonzero(non_zeros)
            total_num = non_zeros.size
            sparsity = 1 - (num_nonzeros * 1.0) / total_num
            if sparsity > 0.01:
                print("{}, {}, {}, {}, {}".format(name, non_zeros.shape, num_nonzeros, total_num, sparsity))
                total_nz += num_nonzeros
                total += total_num
        if total > 0:
            print("Overall sparsity for layers with sparsity >0.01: {}".format(1 - float(total_nz)/total))
        else:
            print("All layers are dense!")
        return

    if compressed_view is True:
        total_w_num = 0
        total_w_num_nz = 0
        for (name, W) in model.named_parameters():
            if "weight" in name:
                non_zeros = W.detach().cpu().numpy().astype(np.float32) != 0
                num_nonzeros = np.count_nonzero(non_zeros)
                total_w_num_nz += num_nonzeros
                total_num = non_zeros.size
                total_w_num += total_num
        sparsity = 1 - (total_w_num_nz * 1.0) / total_w_num
        print("The sparsity of all params with 'weights' in its name: num_nonzeros, total_num, sparsity")
        print("{}, {}, {}".format(total_w_num_nz, total_w_num, sparsity))
        return

    print("The sparsity of all parameters: name, num_nonzeros, total_num, shape, sparsity")
    for (name, W) in model.named_parameters():
        non_zeros = W.detach().cpu().numpy().astype(np.float32) != 0
        num_nonzeros = np.count_nonzero(non_zeros)
        total_num = non_zeros.size
        sparsity = 1 - (num_nonzeros * 1.0) / total_num
        print("{}: {}, {}, {}, [{}]".format(name, str(num_nonzeros), str(total_num), non_zeros.shape, str(sparsity)))

def prune_update_learning_rate(optimizer, epoch, args):
    if prune_algo == None:
        return None
    return admm_adjust_learning_rate(optimizer, epoch, args)

# do not use, will be deprecated
def prune_retrain_apply_masks():
    apply_masks()

def prune_generate_yaml(model, sparsity, yaml_filename=None):
    if yaml_filename is None:
        yaml_filename = 'sp_{}.yaml'.format(sparsity)
    with open(yaml_filename, 'w') as f:
        f.write("prune_ratios: \n")
    num_w = 0
    for name, W in model.named_parameters():
        print(name, W.shape)
        num_w += W.detach().cpu().numpy().size
        if len(W.detach().cpu().numpy().shape) > 1:
            with open(yaml_filename, 'a') as f:
                if 'module.' in name:
                    f.write("{}: {}\n".format(name[7:], sparsity))
                else:
                    f.write("{}: {}\n".format(name, sparsity))
    print("Yaml file {} generated".format(yaml_filename))
    print("Total number of parameters: {}".format(num_w))
...
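Taken together, prune_main.py exposes a small functional API around the module-level prune_algo / retrain pair: register command-line options with prune_parse_arguments, build the pruning objects with prune_init, and call the prune_update* and prune_apply_masks* helpers from the training loop. The sketch below shows one plausible way to wire these calls together; the toy model, data, optimizer, and the chosen flags are assumptions for illustration, and only the prune_* calls come from the listing above.

import argparse
import torch

# Hypothetical wiring of the API above into a training loop.
parser = argparse.ArgumentParser()
prune_parse_arguments(parser)   # registers the --sp-* options from all sub-modules
args = parser.parse_args()      # e.g. pass --sp-admm on the command line to select ADMM pruning

model = torch.nn.Linear(784, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()

# Toy data standing in for a real data loader.
loader = [(torch.randn(32, 784), torch.randint(0, 10, (32,))) for _ in range(10)]
num_epochs = 2

prune_init(args, model)         # creates the ADMM / Retrain objects based on the flags

for epoch in range(num_epochs):
    for batch_idx, (x, y) in enumerate(loader):
        prune_update(epoch, batch_idx)    # update ADMM targets / retrain masks
        loss = criterion(model(x), y)
        loss = prune_update_loss(loss)    # add the pruning regularization term, if any
        optimizer.zero_grad()
        loss.backward()
        prune_apply_masks_on_grads()      # keep gradients of pruned weights at zero
        optimizer.step()
        prune_apply_masks()               # re-zero pruned weights after the update

prune_harden()                  # make the sparsity pattern permanent
prune_print_sparsity(model)     # report per-layer sparsity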


__init__.py

Source: __init__.py (GitHub)


import os
import sys
# file_dir = os.path.dirname(__file__)
# sys.path.append(file_dir)
from .prune_main import prune_parse_arguments, \
    prune_init, \
    prune_update, \
    prune_update_grad, \
    prune_update_loss, \
    prune_update_combined_loss, \
    prune_harden, \
    prune_apply_masks, \
    prune_apply_masks_on_grads, \
    prune_store_prune_params, \
    prune_update_learning_rate, \
    prune_store_weights, \
    prune_generate_yaml, \
    prune_fix_layer_restore
    #prune_harden_first_before_retrain

# debug functions
from .prune_main import prune_print_sparsity, prune_retrain_show_masks

# will be deprecated functions
from .prune_main import prune_retrain_apply_masks
...
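This __init__.py simply re-exports the prune_main API, so callers import everything from the package root. A minimal import sketch, assuming the package directory is named prune_utils (the actual name is not shown in the snippet):

# The package name `prune_utils` is an assumption; substitute the directory
# that actually contains this __init__.py.
from prune_utils import prune_parse_arguments, prune_init, prune_update, prune_apply_masks
from prune_utils import prune_print_sparsity   # debug helper re-exported above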


Automation Testing Tutorials

Learn to execute automation testing from scratch with the LambdaTest Learning Hub, from setting up the prerequisites and running your first automation test to following best practices and diving into advanced test scenarios. The LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with different test automation frameworks such as Selenium, Cypress, and TestNG.

LambdaTest Learning Hubs:

YouTube

You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
