Best Python code snippet using slash
test_problems.py
Source:test_problems.py
1#2# This file is part of the chi repository3# (https://github.com/DavAug/chi/) which is released under the4# BSD 3-clause license. See accompanying LICENSE.md for copyright notice and5# full license details.6#7import copy8import unittest9import numpy as np10import pandas as pd11import pints12import chi13from chi.library import ModelLibrary14class TestProblemModellingControllerPDProblem(unittest.TestCase):15 """16 Tests the chi.ProblemModellingController class on a PD modelling17 problem.18 """19 @classmethod20 def setUpClass(cls):21 # Create test dataset22 ids_v = [0, 0, 0, 1, 1, 1, 2, 2]23 times_v = [0, 1, 2, 2, np.nan, 4, 1, 3]24 volumes = [np.nan, 0.3, 0.2, 0.5, 0.1, 0.2, 0.234, np.nan]25 ids_c = [0, 0, 1, 1]26 times_c = [0, 1, 2, np.nan]27 cytokines = [3.4, 0.3, 0.5, np.nan]28 ids_d = [0, 1, 1, 1, 2, 2]29 times_d = [0, np.nan, 4, 1, 3, 3]30 dose = [3.4, np.nan, 0.5, 0.5, np.nan, np.nan]31 duration = [0.01, np.nan, 0.31, np.nan, 0.5, np.nan]32 ids_cov = [0, 1, 2]33 times_cov = [np.nan, 1, np.nan]34 age = [10, 14, 12]35 cls.data = pd.DataFrame({36 'ID': ids_v + ids_c + ids_d + ids_cov,37 'Time': times_v + times_c + times_d + times_cov,38 'Observable':39 ['Tumour volume'] * 8 + ['IL 6'] * 4 + [np.nan] * 6 +40 ['Age'] * 3,41 'Value': volumes + cytokines + [np.nan] * 6 + age,42 'Dose': [np.nan] * 12 + dose + [np.nan] * 3,43 'Duration': [np.nan] * 12 + duration + [np.nan] * 3})44 # Test case I: create PD modelling problem45 lib = ModelLibrary()46 path = lib.tumour_growth_inhibition_model_koch()47 cls.pd_model = chi.PharmacodynamicModel(path)48 cls.error_model = chi.ConstantAndMultiplicativeGaussianErrorModel()49 cls.pd_problem = chi.ProblemModellingController(50 cls.pd_model, cls.error_model)51 # Test case II: create PKPD modelling problem52 lib = ModelLibrary()53 path = lib.erlotinib_tumour_growth_inhibition_model()54 cls.pkpd_model = chi.PharmacokineticModel(path)55 cls.pkpd_model.set_outputs([56 'central.drug_concentration',57 'myokit.tumour_volume'])58 
cls.error_models = [59 chi.ConstantAndMultiplicativeGaussianErrorModel(),60 chi.ConstantAndMultiplicativeGaussianErrorModel()]61 cls.pkpd_problem = chi.ProblemModellingController(62 cls.pkpd_model, cls.error_models,63 outputs=[64 'central.drug_concentration',65 'myokit.tumour_volume'])66 def test_bad_input(self):67 # Mechanistic model has wrong type68 mechanistic_model = 'wrong type'69 with self.assertRaisesRegex(TypeError, 'The mechanistic model'):70 chi.ProblemModellingController(71 mechanistic_model, self.error_model)72 # Error model has wrong type73 error_model = 'wrong type'74 with self.assertRaisesRegex(TypeError, 'Error models have to be'):75 chi.ProblemModellingController(76 self.pd_model, error_model)77 error_models = ['wrong', 'type']78 with self.assertRaisesRegex(TypeError, 'Error models have to be'):79 chi.ProblemModellingController(80 self.pd_model, error_models)81 # Wrong number of error models82 error_model = chi.ConstantAndMultiplicativeGaussianErrorModel()83 with self.assertRaisesRegex(ValueError, 'Wrong number of error'):84 chi.ProblemModellingController(85 self.pkpd_model, error_model)86 error_models = [87 chi.ConstantAndMultiplicativeGaussianErrorModel(),88 chi.ConstantAndMultiplicativeGaussianErrorModel()]89 with self.assertRaisesRegex(ValueError, 'Wrong number of error'):90 chi.ProblemModellingController(91 self.pd_model, error_models)92 def test_fix_parameters(self):93 # Test case I: PD model94 # Fix model parameters95 name_value_dict = dict({96 'myokit.drug_concentration': 0,97 'Sigma base': 1})98 self.pd_problem.fix_parameters(name_value_dict)99 self.assertEqual(self.pd_problem.get_n_parameters(), 5)100 param_names = self.pd_problem.get_parameter_names()101 self.assertEqual(len(param_names), 5)102 self.assertEqual(param_names[0], 'myokit.tumour_volume')103 self.assertEqual(param_names[1], 'myokit.kappa')104 self.assertEqual(param_names[2], 'myokit.lambda_0')105 self.assertEqual(param_names[3], 'myokit.lambda_1')106 
self.assertEqual(param_names[4], 'Sigma rel.')107 # Free and fix a parameter108 name_value_dict = dict({109 'myokit.lambda_1': 2,110 'Sigma base': None})111 self.pd_problem.fix_parameters(name_value_dict)112 self.assertEqual(self.pd_problem.get_n_parameters(), 5)113 param_names = self.pd_problem.get_parameter_names()114 self.assertEqual(len(param_names), 5)115 self.assertEqual(param_names[0], 'myokit.tumour_volume')116 self.assertEqual(param_names[1], 'myokit.kappa')117 self.assertEqual(param_names[2], 'myokit.lambda_0')118 self.assertEqual(param_names[3], 'Sigma base')119 self.assertEqual(param_names[4], 'Sigma rel.')120 # Free all parameters again121 name_value_dict = dict({122 'myokit.lambda_1': None,123 'myokit.drug_concentration': None})124 self.pd_problem.fix_parameters(name_value_dict)125 self.assertEqual(self.pd_problem.get_n_parameters(), 7)126 param_names = self.pd_problem.get_parameter_names()127 self.assertEqual(len(param_names), 7)128 self.assertEqual(param_names[0], 'myokit.tumour_volume')129 self.assertEqual(param_names[1], 'myokit.drug_concentration')130 self.assertEqual(param_names[2], 'myokit.kappa')131 self.assertEqual(param_names[3], 'myokit.lambda_0')132 self.assertEqual(param_names[4], 'myokit.lambda_1')133 self.assertEqual(param_names[5], 'Sigma base')134 self.assertEqual(param_names[6], 'Sigma rel.')135 # Fix parameters before setting a population model136 problem = copy.copy(self.pd_problem)137 name_value_dict = dict({138 'myokit.tumour_volume': 1,139 'myokit.drug_concentration': 0,140 'myokit.kappa': 1,141 'myokit.lambda_1': 2})142 problem.fix_parameters(name_value_dict)143 problem.set_population_model(144 pop_models=[145 chi.HeterogeneousModel(),146 chi.PooledModel(),147 chi.LogNormalModel()])148 problem.set_data(149 self.data,150 output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})151 n_ids = 3152 self.assertEqual(problem.get_n_parameters(), 2 * n_ids + 1 + 2)153 param_names = problem.get_parameter_names()154 
self.assertEqual(len(param_names), 9)155 self.assertEqual(param_names[0], 'ID 0: myokit.lambda_0')156 self.assertEqual(param_names[1], 'ID 1: myokit.lambda_0')157 self.assertEqual(param_names[2], 'ID 2: myokit.lambda_0')158 self.assertEqual(param_names[3], 'Pooled Sigma base')159 self.assertEqual(param_names[4], 'ID 0: Sigma rel.')160 self.assertEqual(param_names[5], 'ID 1: Sigma rel.')161 self.assertEqual(param_names[6], 'ID 2: Sigma rel.')162 self.assertEqual(param_names[7], 'Mean log Sigma rel.')163 self.assertEqual(param_names[8], 'Std. log Sigma rel.')164 # Fix parameters after setting a population model165 # (Only population models can be fixed)166 name_value_dict = dict({167 'ID 1: myokit.lambda_0': 1,168 'ID 2: myokit.lambda_0': 4,169 'Pooled Sigma base': 2})170 problem.fix_parameters(name_value_dict)171 # self.assertEqual(problem.get_n_parameters(), 8)172 param_names = problem.get_parameter_names()173 self.assertEqual(len(param_names), 8)174 self.assertEqual(param_names[0], 'ID 0: myokit.lambda_0')175 self.assertEqual(param_names[1], 'ID 1: myokit.lambda_0')176 self.assertEqual(param_names[2], 'ID 2: myokit.lambda_0')177 self.assertEqual(param_names[3], 'ID 0: Sigma rel.')178 self.assertEqual(param_names[4], 'ID 1: Sigma rel.')179 self.assertEqual(param_names[5], 'ID 2: Sigma rel.')180 self.assertEqual(param_names[6], 'Mean log Sigma rel.')181 self.assertEqual(param_names[7], 'Std. 
log Sigma rel.')182 # Test case II: PKPD model183 # Fix model parameters184 name_value_dict = dict({185 'myokit.kappa': 0,186 'central.drug_concentration Sigma base': 1})187 self.pkpd_problem.fix_parameters(name_value_dict)188 self.assertEqual(self.pkpd_problem.get_n_parameters(), 9)189 param_names = self.pkpd_problem.get_parameter_names()190 self.assertEqual(len(param_names), 9)191 self.assertEqual(param_names[0], 'central.drug_amount')192 self.assertEqual(param_names[1], 'myokit.tumour_volume')193 self.assertEqual(param_names[2], 'central.size')194 self.assertEqual(param_names[3], 'myokit.critical_volume')195 self.assertEqual(param_names[4], 'myokit.elimination_rate')196 self.assertEqual(param_names[5], 'myokit.lambda')197 self.assertEqual(198 param_names[6], 'central.drug_concentration Sigma rel.')199 self.assertEqual(param_names[7], 'myokit.tumour_volume Sigma base')200 self.assertEqual(param_names[8], 'myokit.tumour_volume Sigma rel.')201 # Free and fix a parameter202 name_value_dict = dict({203 'myokit.lambda': 2,204 'myokit.kappa': None})205 self.pkpd_problem.fix_parameters(name_value_dict)206 self.assertEqual(self.pkpd_problem.get_n_parameters(), 9)207 param_names = self.pkpd_problem.get_parameter_names()208 self.assertEqual(len(param_names), 9)209 self.assertEqual(param_names[0], 'central.drug_amount')210 self.assertEqual(param_names[1], 'myokit.tumour_volume')211 self.assertEqual(param_names[2], 'central.size')212 self.assertEqual(param_names[3], 'myokit.critical_volume')213 self.assertEqual(param_names[4], 'myokit.elimination_rate')214 self.assertEqual(param_names[5], 'myokit.kappa')215 self.assertEqual(216 param_names[6], 'central.drug_concentration Sigma rel.')217 self.assertEqual(param_names[7], 'myokit.tumour_volume Sigma base')218 self.assertEqual(param_names[8], 'myokit.tumour_volume Sigma rel.')219 # Free all parameters again220 name_value_dict = dict({221 'myokit.lambda': None,222 'central.drug_concentration Sigma base': None})223 
self.pkpd_problem.fix_parameters(name_value_dict)224 self.assertEqual(self.pkpd_problem.get_n_parameters(), 11)225 param_names = self.pkpd_problem.get_parameter_names()226 self.assertEqual(len(param_names), 11)227 self.assertEqual(param_names[0], 'central.drug_amount')228 self.assertEqual(param_names[1], 'myokit.tumour_volume')229 self.assertEqual(param_names[2], 'central.size')230 self.assertEqual(param_names[3], 'myokit.critical_volume')231 self.assertEqual(param_names[4], 'myokit.elimination_rate')232 self.assertEqual(param_names[5], 'myokit.kappa')233 self.assertEqual(param_names[6], 'myokit.lambda')234 self.assertEqual(235 param_names[7], 'central.drug_concentration Sigma base')236 self.assertEqual(237 param_names[8], 'central.drug_concentration Sigma rel.')238 self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')239 self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')240 def test_fix_parameters_bad_input(self):241 # Input is not a dictionary242 name_value_dict = 'Bad type'243 with self.assertRaisesRegex(ValueError, 'The name-value dictionary'):244 self.pd_problem.fix_parameters(name_value_dict)245 def test_get_covariate_names(self):246 # Test case I: PD model247 problem = copy.deepcopy(self.pd_problem)248 # I.1: No population model249 names = problem.get_covariate_names()250 self.assertEqual(len(names), 0)251 # I.2: Population model but no covariate population model252 pop_models = [chi.PooledModel()] * 7253 problem.set_population_model(pop_models)254 names = problem.get_covariate_names()255 self.assertEqual(len(names), 0)256 names = problem.get_covariate_names(unique=False)257 self.assertEqual(len(names), 7)258 self.assertEqual(names[0], [])259 self.assertEqual(names[1], [])260 self.assertEqual(names[2], [])261 self.assertEqual(names[3], [])262 self.assertEqual(names[3], [])263 self.assertEqual(names[3], [])264 self.assertEqual(names[3], [])265 # I.3: With covariate models266 cov_pop_model1 = chi.CovariatePopulationModel(267 
chi.GaussianModel(),268 chi.LogNormalLinearCovariateModel(n_covariates=2)269 )270 cov_pop_model1.set_covariate_names(['Age', 'Sex'])271 cov_pop_model2 = chi.CovariatePopulationModel(272 chi.GaussianModel(),273 chi.LogNormalLinearCovariateModel(n_covariates=3)274 )275 cov_pop_model2.set_covariate_names(['SNP', 'Age', 'Height'])276 pop_models = [277 chi.PooledModel(),278 cov_pop_model1,279 chi.PooledModel(),280 cov_pop_model2,281 cov_pop_model1,282 chi.PooledModel(),283 chi.PooledModel()284 ]285 problem.set_population_model(pop_models)286 names = problem.get_covariate_names()287 self.assertEqual(len(names), 4)288 self.assertEqual(names[0], 'Age')289 self.assertEqual(names[1], 'Sex')290 self.assertEqual(names[2], 'SNP')291 self.assertEqual(names[3], 'Height')292 names = problem.get_covariate_names(unique=False)293 self.assertEqual(len(names), 7)294 self.assertEqual(names[0], [])295 self.assertEqual(names[1], ['Age', 'Sex'])296 self.assertEqual(names[2], [])297 self.assertEqual(names[3], ['SNP', 'Age', 'Height'])298 self.assertEqual(names[4], ['Age', 'Sex'])299 self.assertEqual(names[5], [])300 self.assertEqual(names[6], [])301 def test_get_dosing_regimens(self):302 # Test case I: PD problem303 problem = copy.deepcopy(self.pd_problem)304 # No data has been set305 regimens = problem.get_dosing_regimens()306 self.assertIsNone(regimens)307 # Set data, but because PD model, no dosing regimen can be set308 problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})309 regimens = problem.get_dosing_regimens()310 self.assertIsNone(regimens)311 # Test case II: PKPD problem312 problem = copy.deepcopy(self.pkpd_problem)313 # No data has been set314 regimens = problem.get_dosing_regimens()315 self.assertIsNone(regimens)316 # Data has been set, but duration is ignored317 problem.set_data(318 self.data,319 output_observable_dict={320 'myokit.tumour_volume': 'Tumour volume',321 'central.drug_concentration': 'IL 6'},322 dose_duration_key=None)323 regimens = 
problem.get_dosing_regimens()324 self.assertIsInstance(regimens, dict)325 # Data has been set with duration information326 problem.set_data(327 self.data,328 output_observable_dict={329 'myokit.tumour_volume': 'Tumour volume',330 'central.drug_concentration': 'IL 6'})331 regimens = problem.get_dosing_regimens()332 self.assertIsInstance(regimens, dict)333 def test_get_log_prior(self):334 # Log-prior is extensively tested with get_log_posterior335 # method336 self.assertIsNone(self.pd_problem.get_log_prior())337 def test_get_log_posterior(self):338 # Test case I: Create posterior with no fixed parameters339 problem = copy.deepcopy(self.pd_problem)340 # Set data which does not provide measurements for all IDs341 problem.set_data(342 self.data,343 output_observable_dict={'myokit.tumour_volume': 'IL 6'})344 problem.set_log_prior([345 pints.HalfCauchyLogPrior(0, 1)]*7)346 # Get all posteriors347 posteriors = problem.get_log_posterior()348 self.assertEqual(len(posteriors), 2)349 self.assertEqual(posteriors[0].n_parameters(), 7)350 self.assertEqual(posteriors[0].get_id(), 'ID 0')351 self.assertEqual(posteriors[1].n_parameters(), 7)352 self.assertEqual(posteriors[1].get_id(), 'ID 1')353 # Set data that has measurements for all IDs354 problem.set_data(355 self.data,356 output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})357 problem.set_log_prior([358 pints.HalfCauchyLogPrior(0, 1)]*7)359 # Get all posteriors360 posteriors = problem.get_log_posterior()361 self.assertEqual(len(posteriors), 3)362 self.assertEqual(posteriors[0].n_parameters(), 7)363 self.assertEqual(posteriors[0].get_id(), 'ID 0')364 self.assertEqual(posteriors[1].n_parameters(), 7)365 self.assertEqual(posteriors[1].get_id(), 'ID 1')366 self.assertEqual(posteriors[2].n_parameters(), 7)367 self.assertEqual(posteriors[2].get_id(), 'ID 2')368 # Get only one posterior369 posterior = problem.get_log_posterior(individual='0')370 self.assertIsInstance(posterior, chi.LogPosterior)371 
self.assertEqual(posterior.n_parameters(), 7)372 self.assertEqual(posterior.get_id(), 'ID 0')373 # Test case II: Fix some parameters374 name_value_dict = dict({375 'myokit.drug_concentration': 0,376 'myokit.kappa': 1})377 problem.fix_parameters(name_value_dict)378 problem.set_log_prior([379 pints.HalfCauchyLogPrior(0, 1)]*5)380 # Get all posteriors381 posteriors = problem.get_log_posterior()382 self.assertEqual(len(posteriors), 3)383 self.assertEqual(posteriors[0].n_parameters(), 5)384 self.assertEqual(posteriors[0].get_id(), 'ID 0')385 self.assertEqual(posteriors[1].n_parameters(), 5)386 self.assertEqual(posteriors[1].get_id(), 'ID 1')387 self.assertEqual(posteriors[2].n_parameters(), 5)388 self.assertEqual(posteriors[2].get_id(), 'ID 2')389 # Get only one posterior390 posterior = problem.get_log_posterior(individual='1')391 self.assertIsInstance(posterior, chi.LogPosterior)392 self.assertEqual(posterior.n_parameters(), 5)393 self.assertEqual(posterior.get_id(), 'ID 1')394 # Set a population model395 cov_pop_model = chi.CovariatePopulationModel(396 chi.GaussianModel(),397 chi.LogNormalLinearCovariateModel(n_covariates=1)398 )399 cov_pop_model.set_covariate_names(['Age'], True)400 pop_models = [401 chi.PooledModel(),402 chi.HeterogeneousModel(),403 chi.PooledModel(),404 chi.PooledModel(),405 cov_pop_model]406 problem.set_population_model(pop_models)407 problem.set_log_prior([408 pints.HalfCauchyLogPrior(0, 1)]*9)409 posterior = problem.get_log_posterior()410 self.assertIsInstance(posterior, chi.HierarchicalLogPosterior)411 self.assertEqual(posterior.n_parameters(), 12)412 names = posterior.get_parameter_names()413 ids = posterior.get_id()414 self.assertEqual(len(names), 12)415 self.assertEqual(len(ids), 12)416 self.assertEqual(names[0], 'Pooled myokit.tumour_volume')417 self.assertIsNone(ids[0])418 self.assertEqual(names[1], 'myokit.lambda_0')419 self.assertEqual(ids[1], 'ID 0')420 self.assertEqual(names[2], 'myokit.lambda_0')421 self.assertEqual(ids[2], 'ID 1')422 
self.assertEqual(names[3], 'myokit.lambda_0')423 self.assertEqual(ids[3], 'ID 2')424 self.assertEqual(names[4], 'Pooled myokit.lambda_1')425 self.assertIsNone(ids[4])426 self.assertEqual(names[5], 'Pooled Sigma base')427 self.assertIsNone(ids[5])428 self.assertEqual(names[6], 'Sigma rel. Eta')429 self.assertEqual(ids[6], 'ID 0')430 self.assertEqual(names[7], 'Sigma rel. Eta')431 self.assertEqual(ids[7], 'ID 1')432 self.assertEqual(names[8], 'Sigma rel. Eta')433 self.assertEqual(ids[8], 'ID 2')434 self.assertEqual(names[9], 'Base mean log Sigma rel.')435 self.assertIsNone(ids[9])436 self.assertEqual(names[10], 'Std. log Sigma rel.')437 self.assertIsNone(ids[10])438 self.assertEqual(names[11], 'Shift Age Sigma rel.')439 self.assertIsNone(ids[11])440 # Make sure that selecting an individual is ignored for population441 # models442 posterior = problem.get_log_posterior(individual='some individual')443 self.assertIsInstance(posterior, chi.HierarchicalLogPosterior)444 self.assertEqual(posterior.n_parameters(), 12)445 names = posterior.get_parameter_names()446 ids = posterior.get_id()447 self.assertEqual(len(names), 12)448 self.assertEqual(len(ids), 12)449 self.assertEqual(names[0], 'Pooled myokit.tumour_volume')450 self.assertIsNone(ids[0])451 self.assertEqual(names[1], 'myokit.lambda_0')452 self.assertEqual(ids[1], 'ID 0')453 self.assertEqual(names[2], 'myokit.lambda_0')454 self.assertEqual(ids[2], 'ID 1')455 self.assertEqual(names[3], 'myokit.lambda_0')456 self.assertEqual(ids[3], 'ID 2')457 self.assertEqual(names[4], 'Pooled myokit.lambda_1')458 self.assertIsNone(ids[4])459 self.assertEqual(names[5], 'Pooled Sigma base')460 self.assertIsNone(ids[5])461 self.assertEqual(names[6], 'Sigma rel. Eta')462 self.assertEqual(ids[6], 'ID 0')463 self.assertEqual(names[7], 'Sigma rel. Eta')464 self.assertEqual(ids[7], 'ID 1')465 self.assertEqual(names[8], 'Sigma rel. 
Eta')466 self.assertEqual(ids[8], 'ID 2')467 self.assertEqual(names[9], 'Base mean log Sigma rel.')468 self.assertIsNone(ids[9])469 self.assertEqual(names[10], 'Std. log Sigma rel.')470 self.assertIsNone(ids[10])471 self.assertEqual(names[11], 'Shift Age Sigma rel.')472 self.assertIsNone(ids[11])473 def test_get_log_posteriors_bad_input(self):474 problem = copy.deepcopy(self.pd_problem)475 # No log-prior has been set476 problem.set_data(477 self.data,478 output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})479 with self.assertRaisesRegex(ValueError, 'The log-prior has not'):480 problem.get_log_posterior()481 # The selected individual does not exist482 individual = 'Not existent'483 problem.set_log_prior([pints.HalfCauchyLogPrior(0, 1)]*7)484 with self.assertRaisesRegex(ValueError, 'The individual cannot'):485 problem.get_log_posterior(individual)486 def test_get_n_parameters(self):487 # Test case I: PD model488 # Test case I.1: No population model489 # Test default flag490 problem = copy.deepcopy(self.pd_problem)491 n_parameters = problem.get_n_parameters()492 self.assertEqual(n_parameters, 7)493 # Test exclude population model True494 n_parameters = problem.get_n_parameters(exclude_pop_model=True)495 self.assertEqual(n_parameters, 7)496 # Test exclude bottom-level model True497 n_parameters = problem.get_n_parameters(exclude_bottom_level=True)498 self.assertEqual(n_parameters, 7)499 # Test case I.2: Population model500 pop_models = [501 chi.PooledModel(),502 chi.PooledModel(),503 chi.HeterogeneousModel(),504 chi.PooledModel(),505 chi.PooledModel(),506 chi.LogNormalModel(),507 chi.LogNormalModel()]508 problem.set_population_model(pop_models)509 n_parameters = problem.get_n_parameters()510 self.assertEqual(n_parameters, 8)511 # Test exclude population model True512 n_parameters = problem.get_n_parameters(exclude_pop_model=True)513 self.assertEqual(n_parameters, 7)514 # Test exclude bottom-level model True515 n_parameters = 
problem.get_n_parameters(exclude_bottom_level=True)516 self.assertEqual(n_parameters, 8)517 # Test case I.3: Set data518 problem.set_data(519 self.data,520 output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})521 n_parameters = problem.get_n_parameters()522 self.assertEqual(n_parameters, 17)523 # Test exclude population model True524 n_parameters = problem.get_n_parameters(exclude_pop_model=True)525 self.assertEqual(n_parameters, 7)526 # Test exclude bottom-level model True527 n_parameters = problem.get_n_parameters(exclude_bottom_level=True)528 self.assertEqual(n_parameters, 11)529 # Test case II: PKPD model530 # Test case II.1: No population model531 # Test default flag532 problem = copy.deepcopy(self.pkpd_problem)533 n_parameters = problem.get_n_parameters()534 self.assertEqual(n_parameters, 11)535 # Test exclude population model True536 n_parameters = problem.get_n_parameters(exclude_pop_model=True)537 self.assertEqual(n_parameters, 11)538 # Test exclude bottom-level model True539 n_parameters = problem.get_n_parameters(exclude_bottom_level=True)540 self.assertEqual(n_parameters, 11)541 # Test case II.2: Population model542 pop_models = [543 chi.PooledModel(),544 chi.PooledModel(),545 chi.HeterogeneousModel(),546 chi.PooledModel(),547 chi.PooledModel(),548 chi.LogNormalModel(),549 chi.LogNormalModel(),550 chi.PooledModel(),551 chi.PooledModel(),552 chi.PooledModel(),553 chi.PooledModel()]554 problem.set_population_model(pop_models)555 n_parameters = problem.get_n_parameters()556 self.assertEqual(n_parameters, 12)557 # Test exclude population model True558 n_parameters = problem.get_n_parameters(exclude_pop_model=True)559 self.assertEqual(n_parameters, 11)560 # Test exclude bottom-level model True561 n_parameters = problem.get_n_parameters(exclude_bottom_level=True)562 self.assertEqual(n_parameters, 12)563 # Test case II.3: Set data564 problem.set_data(565 self.data,566 output_observable_dict={567 'myokit.tumour_volume': 'Tumour volume',568 
'central.drug_concentration': 'IL 6'})569 n_parameters = problem.get_n_parameters()570 self.assertEqual(n_parameters, 21)571 # Test exclude population model True572 n_parameters = problem.get_n_parameters(exclude_pop_model=True)573 self.assertEqual(n_parameters, 11)574 # Test exclude bottom-level model True575 n_parameters = problem.get_n_parameters(exclude_bottom_level=True)576 self.assertEqual(n_parameters, 15)577 def test_get_parameter_names(self):578 # Test case I: PD model579 problem = copy.deepcopy(self.pd_problem)580 # Test case I.1: No population model581 # Test default flag582 param_names = problem.get_parameter_names()583 self.assertEqual(len(param_names), 7)584 self.assertEqual(param_names[0], 'myokit.tumour_volume')585 self.assertEqual(param_names[1], 'myokit.drug_concentration')586 self.assertEqual(param_names[2], 'myokit.kappa')587 self.assertEqual(param_names[3], 'myokit.lambda_0')588 self.assertEqual(param_names[4], 'myokit.lambda_1')589 self.assertEqual(param_names[5], 'Sigma base')590 self.assertEqual(param_names[6], 'Sigma rel.')591 # Check that also works with exclude pop params flag592 param_names = problem.get_parameter_names(exclude_pop_model=True)593 self.assertEqual(len(param_names), 7)594 self.assertEqual(param_names[0], 'myokit.tumour_volume')595 self.assertEqual(param_names[1], 'myokit.drug_concentration')596 self.assertEqual(param_names[2], 'myokit.kappa')597 self.assertEqual(param_names[3], 'myokit.lambda_0')598 self.assertEqual(param_names[4], 'myokit.lambda_1')599 self.assertEqual(param_names[5], 'Sigma base')600 self.assertEqual(param_names[6], 'Sigma rel.')601 # Check that also works with exclude bottom-level flag602 param_names = problem.get_parameter_names(exclude_bottom_level=True)603 self.assertEqual(len(param_names), 7)604 self.assertEqual(param_names[0], 'myokit.tumour_volume')605 self.assertEqual(param_names[1], 'myokit.drug_concentration')606 self.assertEqual(param_names[2], 'myokit.kappa')607 
self.assertEqual(param_names[3], 'myokit.lambda_0')608 self.assertEqual(param_names[4], 'myokit.lambda_1')609 self.assertEqual(param_names[5], 'Sigma base')610 self.assertEqual(param_names[6], 'Sigma rel.')611 # Test case I.2: Population model612 cov_population_model = chi.CovariatePopulationModel(613 chi.GaussianModel(), chi.LogNormalLinearCovariateModel())614 pop_models = [615 chi.PooledModel(),616 chi.PooledModel(),617 chi.HeterogeneousModel(),618 chi.PooledModel(),619 chi.PooledModel(),620 cov_population_model,621 chi.LogNormalModel()]622 problem.set_population_model(pop_models)623 param_names = problem.get_parameter_names()624 self.assertEqual(len(param_names), 8)625 self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')626 self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')627 self.assertEqual(param_names[2], 'Pooled myokit.lambda_0')628 self.assertEqual(param_names[3], 'Pooled myokit.lambda_1')629 self.assertEqual(param_names[4], 'Base mean log Sigma base')630 self.assertEqual(param_names[5], 'Std. log Sigma base')631 self.assertEqual(param_names[6], 'Mean log Sigma rel.')632 self.assertEqual(param_names[7], 'Std. 
log Sigma rel.')633 # Test exclude population model True634 param_names = problem.get_parameter_names(exclude_pop_model=True)635 self.assertEqual(len(param_names), 7)636 self.assertEqual(param_names[0], 'myokit.tumour_volume')637 self.assertEqual(param_names[1], 'myokit.drug_concentration')638 self.assertEqual(param_names[2], 'myokit.kappa')639 self.assertEqual(param_names[3], 'myokit.lambda_0')640 self.assertEqual(param_names[4], 'myokit.lambda_1')641 self.assertEqual(param_names[5], 'Sigma base')642 self.assertEqual(param_names[6], 'Sigma rel.')643 # Test exclude bottom-level True644 param_names = problem.get_parameter_names(exclude_bottom_level=True)645 self.assertEqual(len(param_names), 8)646 self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')647 self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')648 self.assertEqual(param_names[2], 'Pooled myokit.lambda_0')649 self.assertEqual(param_names[3], 'Pooled myokit.lambda_1')650 self.assertEqual(param_names[4], 'Base mean log Sigma base')651 self.assertEqual(param_names[5], 'Std. log Sigma base')652 self.assertEqual(param_names[6], 'Mean log Sigma rel.')653 self.assertEqual(param_names[7], 'Std. 
log Sigma rel.')654 # Test case I.3: Set data655 problem.set_data(656 self.data,657 output_observable_dict={'myokit.tumour_volume': 'Tumour volume'})658 param_names = problem.get_parameter_names()659 self.assertEqual(len(param_names), 17)660 self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')661 self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')662 self.assertEqual(param_names[2], 'ID 0: myokit.kappa')663 self.assertEqual(param_names[3], 'ID 1: myokit.kappa')664 self.assertEqual(param_names[4], 'ID 2: myokit.kappa')665 self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')666 self.assertEqual(param_names[6], 'Pooled myokit.lambda_1')667 self.assertEqual(param_names[7], 'ID 0: Sigma base Eta')668 self.assertEqual(param_names[8], 'ID 1: Sigma base Eta')669 self.assertEqual(param_names[9], 'ID 2: Sigma base Eta')670 self.assertEqual(param_names[10], 'Base mean log Sigma base')671 self.assertEqual(param_names[11], 'Std. log Sigma base')672 self.assertEqual(param_names[12], 'ID 0: Sigma rel.')673 self.assertEqual(param_names[13], 'ID 1: Sigma rel.')674 self.assertEqual(param_names[14], 'ID 2: Sigma rel.')675 self.assertEqual(param_names[15], 'Mean log Sigma rel.')676 self.assertEqual(param_names[16], 'Std. 
log Sigma rel.')677 # Test exclude population model True678 param_names = problem.get_parameter_names(exclude_pop_model=True)679 self.assertEqual(len(param_names), 7)680 self.assertEqual(param_names[0], 'myokit.tumour_volume')681 self.assertEqual(param_names[1], 'myokit.drug_concentration')682 self.assertEqual(param_names[2], 'myokit.kappa')683 self.assertEqual(param_names[3], 'myokit.lambda_0')684 self.assertEqual(param_names[4], 'myokit.lambda_1')685 self.assertEqual(param_names[5], 'Sigma base')686 self.assertEqual(param_names[6], 'Sigma rel.')687 # Test exclude bottom-level True688 param_names = problem.get_parameter_names(exclude_bottom_level=True)689 self.assertEqual(len(param_names), 11)690 self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')691 self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')692 self.assertEqual(param_names[2], 'ID 0: myokit.kappa')693 self.assertEqual(param_names[3], 'ID 1: myokit.kappa')694 self.assertEqual(param_names[4], 'ID 2: myokit.kappa')695 self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')696 self.assertEqual(param_names[6], 'Pooled myokit.lambda_1')697 self.assertEqual(param_names[7], 'Base mean log Sigma base')698 self.assertEqual(param_names[8], 'Std. log Sigma base')699 self.assertEqual(param_names[9], 'Mean log Sigma rel.')700 self.assertEqual(param_names[10], 'Std. 
log Sigma rel.')701 # Test case II: PKPD model702 problem = copy.deepcopy(self.pkpd_problem)703 # Test case II.1: No population model704 # Test default flag705 param_names = problem.get_parameter_names()706 self.assertEqual(len(param_names), 11)707 self.assertEqual(param_names[0], 'central.drug_amount')708 self.assertEqual(param_names[1], 'myokit.tumour_volume')709 self.assertEqual(param_names[2], 'central.size')710 self.assertEqual(param_names[3], 'myokit.critical_volume')711 self.assertEqual(param_names[4], 'myokit.elimination_rate')712 self.assertEqual(param_names[5], 'myokit.kappa')713 self.assertEqual(param_names[6], 'myokit.lambda')714 self.assertEqual(715 param_names[7], 'central.drug_concentration Sigma base')716 self.assertEqual(717 param_names[8], 'central.drug_concentration Sigma rel.')718 self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')719 self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')720 # Test exclude population model True721 param_names = problem.get_parameter_names(exclude_pop_model=True)722 self.assertEqual(len(param_names), 11)723 self.assertEqual(param_names[0], 'central.drug_amount')724 self.assertEqual(param_names[1], 'myokit.tumour_volume')725 self.assertEqual(param_names[2], 'central.size')726 self.assertEqual(param_names[3], 'myokit.critical_volume')727 self.assertEqual(param_names[4], 'myokit.elimination_rate')728 self.assertEqual(param_names[5], 'myokit.kappa')729 self.assertEqual(param_names[6], 'myokit.lambda')730 self.assertEqual(731 param_names[7], 'central.drug_concentration Sigma base')732 self.assertEqual(733 param_names[8], 'central.drug_concentration Sigma rel.')734 self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')735 self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')736 # Test exclude population model True737 param_names = problem.get_parameter_names(exclude_bottom_level=True)738 self.assertEqual(len(param_names), 11)739 self.assertEqual(param_names[0], 
'central.drug_amount')740 self.assertEqual(param_names[1], 'myokit.tumour_volume')741 self.assertEqual(param_names[2], 'central.size')742 self.assertEqual(param_names[3], 'myokit.critical_volume')743 self.assertEqual(param_names[4], 'myokit.elimination_rate')744 self.assertEqual(param_names[5], 'myokit.kappa')745 self.assertEqual(param_names[6], 'myokit.lambda')746 self.assertEqual(747 param_names[7], 'central.drug_concentration Sigma base')748 self.assertEqual(749 param_names[8], 'central.drug_concentration Sigma rel.')750 self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')751 self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')752 # Test case II.2: Population model753 cov_pop_model = chi.CovariatePopulationModel(754 chi.GaussianModel(),755 chi.LogNormalLinearCovariateModel(n_covariates=1)756 )757 cov_pop_model.set_covariate_names(['Age'], True)758 pop_models = [759 chi.PooledModel(),760 chi.PooledModel(),761 chi.HeterogeneousModel(),762 chi.PooledModel(),763 chi.PooledModel(),764 chi.LogNormalModel(),765 chi.LogNormalModel(),766 chi.PooledModel(),767 cov_pop_model,768 chi.PooledModel(),769 chi.PooledModel()]770 problem.set_population_model(pop_models)771 param_names = problem.get_parameter_names()772 self.assertEqual(len(param_names), 14)773 self.assertEqual(param_names[0], 'Pooled central.drug_amount')774 self.assertEqual(param_names[1], 'Pooled myokit.tumour_volume')775 self.assertEqual(param_names[2], 'Pooled myokit.critical_volume')776 self.assertEqual(param_names[3], 'Pooled myokit.elimination_rate')777 self.assertEqual(param_names[4], 'Mean log myokit.kappa')778 self.assertEqual(param_names[5], 'Std. log myokit.kappa')779 self.assertEqual(param_names[6], 'Mean log myokit.lambda')780 self.assertEqual(param_names[7], 'Std. 
log myokit.lambda')781 self.assertEqual(782 param_names[8], 'Pooled central.drug_concentration Sigma base')783 self.assertEqual(784 param_names[9],785 'Base mean log central.drug_concentration Sigma rel.')786 self.assertEqual(787 param_names[10], 'Std. log central.drug_concentration Sigma rel.')788 self.assertEqual(789 param_names[11], 'Shift Age central.drug_concentration Sigma rel.')790 self.assertEqual(791 param_names[12], 'Pooled myokit.tumour_volume Sigma base')792 self.assertEqual(793 param_names[13], 'Pooled myokit.tumour_volume Sigma rel.')794 # Test exclude population model True795 param_names = problem.get_parameter_names(exclude_pop_model=True)796 self.assertEqual(len(param_names), 11)797 self.assertEqual(param_names[0], 'central.drug_amount')798 self.assertEqual(param_names[1], 'myokit.tumour_volume')799 self.assertEqual(param_names[2], 'central.size')800 self.assertEqual(param_names[3], 'myokit.critical_volume')801 self.assertEqual(param_names[4], 'myokit.elimination_rate')802 self.assertEqual(param_names[5], 'myokit.kappa')803 self.assertEqual(param_names[6], 'myokit.lambda')804 self.assertEqual(805 param_names[7], 'central.drug_concentration Sigma base')806 self.assertEqual(807 param_names[8], 'central.drug_concentration Sigma rel.')808 self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')809 self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')810 # Test exclude bottom-level True811 param_names = problem.get_parameter_names(exclude_bottom_level=True)812 self.assertEqual(len(param_names), 14)813 self.assertEqual(param_names[0], 'Pooled central.drug_amount')814 self.assertEqual(param_names[1], 'Pooled myokit.tumour_volume')815 self.assertEqual(param_names[2], 'Pooled myokit.critical_volume')816 self.assertEqual(param_names[3], 'Pooled myokit.elimination_rate')817 self.assertEqual(param_names[4], 'Mean log myokit.kappa')818 self.assertEqual(param_names[5], 'Std. 
log myokit.kappa')819 self.assertEqual(param_names[6], 'Mean log myokit.lambda')820 self.assertEqual(param_names[7], 'Std. log myokit.lambda')821 self.assertEqual(822 param_names[8], 'Pooled central.drug_concentration Sigma base')823 self.assertEqual(824 param_names[9],825 'Base mean log central.drug_concentration Sigma rel.')826 self.assertEqual(827 param_names[10], 'Std. log central.drug_concentration Sigma rel.')828 self.assertEqual(829 param_names[11], 'Shift Age central.drug_concentration Sigma rel.')830 self.assertEqual(831 param_names[12], 'Pooled myokit.tumour_volume Sigma base')832 self.assertEqual(833 param_names[13], 'Pooled myokit.tumour_volume Sigma rel.')834 # Test case II.3: Set data835 problem.set_data(836 self.data,837 output_observable_dict={838 'myokit.tumour_volume': 'Tumour volume',839 'central.drug_concentration': 'IL 6'})840 param_names = problem.get_parameter_names()841 self.assertEqual(len(param_names), 26)842 self.assertEqual(param_names[0], 'Pooled central.drug_amount')843 self.assertEqual(param_names[1], 'Pooled myokit.tumour_volume')844 self.assertEqual(param_names[2], 'ID 0: central.size')845 self.assertEqual(param_names[3], 'ID 1: central.size')846 self.assertEqual(param_names[4], 'ID 2: central.size')847 self.assertEqual(param_names[5], 'Pooled myokit.critical_volume')848 self.assertEqual(param_names[6], 'Pooled myokit.elimination_rate')849 self.assertEqual(param_names[7], 'ID 0: myokit.kappa')850 self.assertEqual(param_names[8], 'ID 1: myokit.kappa')851 self.assertEqual(param_names[9], 'ID 2: myokit.kappa')852 self.assertEqual(param_names[10], 'Mean log myokit.kappa')853 self.assertEqual(param_names[11], 'Std. log myokit.kappa')854 self.assertEqual(param_names[12], 'ID 0: myokit.lambda')855 self.assertEqual(param_names[13], 'ID 1: myokit.lambda')856 self.assertEqual(param_names[14], 'ID 2: myokit.lambda')857 self.assertEqual(param_names[15], 'Mean log myokit.lambda')858 self.assertEqual(param_names[16], 'Std. 
log myokit.lambda')859 self.assertEqual(860 param_names[17], 'Pooled central.drug_concentration Sigma base')861 self.assertEqual(862 param_names[18],863 'ID 0: central.drug_concentration Sigma rel. Eta')864 self.assertEqual(865 param_names[19],866 'ID 1: central.drug_concentration Sigma rel. Eta')867 self.assertEqual(868 param_names[20],869 'ID 2: central.drug_concentration Sigma rel. Eta')870 self.assertEqual(871 param_names[21],872 'Base mean log central.drug_concentration Sigma rel.')873 self.assertEqual(874 param_names[22], 'Std. log central.drug_concentration Sigma rel.')875 self.assertEqual(876 param_names[23], 'Shift Age central.drug_concentration Sigma rel.')877 self.assertEqual(878 param_names[24], 'Pooled myokit.tumour_volume Sigma base')879 self.assertEqual(880 param_names[25], 'Pooled myokit.tumour_volume Sigma rel.')881 # Test exclude population model True882 param_names = problem.get_parameter_names(exclude_pop_model=True)883 self.assertEqual(len(param_names), 11)884 self.assertEqual(param_names[0], 'central.drug_amount')885 self.assertEqual(param_names[1], 'myokit.tumour_volume')886 self.assertEqual(param_names[2], 'central.size')887 self.assertEqual(param_names[3], 'myokit.critical_volume')888 self.assertEqual(param_names[4], 'myokit.elimination_rate')889 self.assertEqual(param_names[5], 'myokit.kappa')890 self.assertEqual(param_names[6], 'myokit.lambda')891 self.assertEqual(892 param_names[7], 'central.drug_concentration Sigma base')893 self.assertEqual(894 param_names[8], 'central.drug_concentration Sigma rel.')895 self.assertEqual(param_names[9], 'myokit.tumour_volume Sigma base')896 self.assertEqual(param_names[10], 'myokit.tumour_volume Sigma rel.')897 # Test exclude bottom-level True898 param_names = problem.get_parameter_names(exclude_bottom_level=True)899 self.assertEqual(len(param_names), 17)900 self.assertEqual(param_names[0], 'Pooled central.drug_amount')901 self.assertEqual(param_names[1], 'Pooled myokit.tumour_volume')902 
self.assertEqual(param_names[2], 'ID 0: central.size')903 self.assertEqual(param_names[3], 'ID 1: central.size')904 self.assertEqual(param_names[4], 'ID 2: central.size')905 self.assertEqual(param_names[5], 'Pooled myokit.critical_volume')906 self.assertEqual(param_names[6], 'Pooled myokit.elimination_rate')907 self.assertEqual(param_names[7], 'Mean log myokit.kappa')908 self.assertEqual(param_names[8], 'Std. log myokit.kappa')909 self.assertEqual(param_names[9], 'Mean log myokit.lambda')910 self.assertEqual(param_names[10], 'Std. log myokit.lambda')911 self.assertEqual(912 param_names[11], 'Pooled central.drug_concentration Sigma base')913 self.assertEqual(914 param_names[12],915 'Base mean log central.drug_concentration Sigma rel.')916 self.assertEqual(917 param_names[13], 'Std. log central.drug_concentration Sigma rel.')918 self.assertEqual(919 param_names[14], 'Shift Age central.drug_concentration Sigma rel.')920 self.assertEqual(921 param_names[15], 'Pooled myokit.tumour_volume Sigma base')922 self.assertEqual(923 param_names[16], 'Pooled myokit.tumour_volume Sigma rel.')924 def test_get_predictive_model(self):925 # Test case I: PD model926 problem = copy.deepcopy(self.pd_problem)927 # Test case I.1: No population model928 predictive_model = problem.get_predictive_model()929 self.assertIsInstance(predictive_model, chi.PredictiveModel)930 # Exclude population model931 predictive_model = problem.get_predictive_model(932 exclude_pop_model=True)933 self.assertIsInstance(predictive_model, chi.PredictiveModel)934 # Test case I.2: Population model935 problem.set_population_model([936 chi.PooledModel(),937 chi.PooledModel(),938 chi.HeterogeneousModel(),939 chi.PooledModel(),940 chi.PooledModel(),941 chi.LogNormalModel(),942 chi.LogNormalModel()])943 predictive_model = problem.get_predictive_model()944 self.assertIsInstance(945 predictive_model, chi.PopulationPredictiveModel)946 # Exclude population model947 predictive_model = problem.get_predictive_model(948 
exclude_pop_model=True)949 self.assertNotIsInstance(950 predictive_model, chi.PopulationPredictiveModel)951 self.assertIsInstance(predictive_model, chi.PredictiveModel)952 # Test case II: PKPD model953 problem = copy.deepcopy(self.pkpd_problem)954 # Test case II.1: No population model955 predictive_model = problem.get_predictive_model()956 self.assertIsInstance(predictive_model, chi.PredictiveModel)957 # Exclude population model958 predictive_model = problem.get_predictive_model(959 exclude_pop_model=True)960 self.assertIsInstance(predictive_model, chi.PredictiveModel)961 # Test case II.2: Population model962 problem.set_population_model([963 chi.PooledModel(),964 chi.PooledModel(),965 chi.HeterogeneousModel(),966 chi.PooledModel(),967 chi.PooledModel(),968 chi.LogNormalModel(),969 chi.LogNormalModel(),970 chi.PooledModel(),971 chi.PooledModel(),972 chi.PooledModel(),973 chi.PooledModel()])974 predictive_model = problem.get_predictive_model()975 self.assertIsInstance(976 predictive_model, chi.PopulationPredictiveModel)977 # Exclude population model978 predictive_model = problem.get_predictive_model(979 exclude_pop_model=True)980 self.assertNotIsInstance(981 predictive_model, chi.PopulationPredictiveModel)982 self.assertIsInstance(predictive_model, chi.PredictiveModel)983 def test_set_data(self):984 # Set data with explicit output-observable map985 problem = copy.deepcopy(self.pd_problem)986 output_observable_dict = {'myokit.tumour_volume': 'Tumour volume'}987 problem.set_data(self.data, output_observable_dict)988 # Set data with implicit output-observable map989 mask = self.data['Observable'] == 'Tumour volume'990 problem.set_data(self.data[mask])991 # Set data with explicit covariate mapping992 cov_pop_model = chi.CovariatePopulationModel(993 chi.GaussianModel(),994 chi.LogNormalLinearCovariateModel(n_covariates=1)995 )996 cov_pop_model.set_covariate_names(['Sex'], True)997 pop_models = [cov_pop_model] * 7998 problem.set_population_model(pop_models)999 
covariate_dict = {'Sex': 'Age'}1000 problem.set_data(self.data, output_observable_dict, covariate_dict)1001 def test_set_data_bad_input(self):1002 # Data has the wrong type1003 data = 'Wrong type'1004 with self.assertRaisesRegex(TypeError, 'Data has to be a'):1005 self.pd_problem.set_data(data)1006 # Data has the wrong ID key1007 data = self.data.rename(columns={'ID': 'Some key'})1008 with self.assertRaisesRegex(ValueError, 'Data does not have the'):1009 self.pkpd_problem.set_data(data)1010 # Data has the wrong time key1011 data = self.data.rename(columns={'Time': 'Some key'})1012 with self.assertRaisesRegex(ValueError, 'Data does not have the'):1013 self.pkpd_problem.set_data(data)1014 # Data has the wrong observable key1015 data = self.data.rename(columns={'Observable': 'Some key'})1016 with self.assertRaisesRegex(ValueError, 'Data does not have the'):1017 self.pkpd_problem.set_data(data)1018 # Data has the wrong value key1019 data = self.data.rename(columns={'Value': 'Some key'})1020 with self.assertRaisesRegex(ValueError, 'Data does not have the'):1021 self.pkpd_problem.set_data(data)1022 # Data has the wrong dose key1023 data = self.data.rename(columns={'Dose': 'Some key'})1024 with self.assertRaisesRegex(ValueError, 'Data does not have the'):1025 self.pkpd_problem.set_data(data)1026 # Data has the wrong duration key1027 data = self.data.rename(columns={'Duration': 'Some key'})1028 with self.assertRaisesRegex(ValueError, 'Data does not have the'):1029 self.pkpd_problem.set_data(data)1030 # The output-observable map does not contain a model output1031 output_observable_dict = {'some output': 'some observable'}1032 with self.assertRaisesRegex(ValueError, 'The output <central.drug'):1033 self.pkpd_problem.set_data(self.data, output_observable_dict)1034 # The output-observable map references a observable that is not in the1035 # dataframe1036 output_observable_dict = {'myokit.tumour_volume': 'some observable'}1037 with self.assertRaisesRegex(ValueError, 'The 
observable <some'):1038 self.pd_problem.set_data(self.data, output_observable_dict)1039 # The model outputs and dataframe observable cannot be trivially mapped1040 with self.assertRaisesRegex(ValueError, 'The observable <central.'):1041 self.pkpd_problem.set_data(self.data)1042 # Covariate map does not contain all model covariates1043 problem = copy.deepcopy(self.pd_problem)1044 cov_pop_model1 = chi.CovariatePopulationModel(1045 chi.GaussianModel(),1046 chi.LogNormalLinearCovariateModel(n_covariates=1)1047 )1048 cov_pop_model1.set_covariate_names(['Age'], True)1049 cov_pop_model2 = chi.CovariatePopulationModel(1050 chi.GaussianModel(),1051 chi.LogNormalLinearCovariateModel(n_covariates=1)1052 )1053 cov_pop_model2.set_covariate_names(['Sex'], True)1054 pop_models = [cov_pop_model1] * 4 + [cov_pop_model2] * 31055 problem.set_population_model(pop_models)1056 output_observable_dict = {'myokit.tumour_volume': 'Tumour volume'}1057 covariate_dict = {'Age': 'Age', 'Something': 'else'}1058 with self.assertRaisesRegex(ValueError, 'The covariate <Sex> could'):1059 problem.set_data(1060 self.data,1061 output_observable_dict=output_observable_dict,1062 covariate_dict=covariate_dict)1063 # Covariate dict maps to covariate that is not in the dataframe1064 covariate_dict = {'Age': 'Age', 'Sex': 'Does not exist'}1065 with self.assertRaisesRegex(ValueError, 'The covariate <Does not ex'):1066 problem.set_data(1067 self.data,1068 output_observable_dict=output_observable_dict,1069 covariate_dict=covariate_dict)1070 # There are no covariate values provided for an ID1071 data = self.data.copy()1072 mask = (data.ID == 1) | (data.Observable == 'Age')1073 data.loc[mask, 'Value'] = np.nan1074 pop_models = [cov_pop_model1] * 71075 problem.set_population_model(pop_models)1076 with self.assertRaisesRegex(ValueError, 'There are either 0 or more'):1077 problem.set_data(1078 data,1079 output_observable_dict=output_observable_dict)1080 # There is more than one covariate value provided for an ID1081 
data = self.data.copy()1082 mask = data.Observable == 'Age'1083 data.loc[mask, 'ID'] = 01084 pop_models = [cov_pop_model1] * 71085 problem.set_population_model(pop_models)1086 with self.assertRaisesRegex(ValueError, 'There are either 0 or more'):1087 problem.set_data(1088 data,1089 output_observable_dict=output_observable_dict)1090 def test_set_log_prior(self):1091 # Test case I: PD model1092 problem = copy.deepcopy(self.pd_problem)1093 problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})1094 log_priors = [pints.HalfCauchyLogPrior(0, 1)] * 71095 # Map priors to parameters automatically1096 problem.set_log_prior(log_priors)1097 # Specify prior parameter map explicitly1098 param_names = [1099 'myokit.kappa',1100 'Sigma base',1101 'Sigma rel.',1102 'myokit.tumour_volume',1103 'myokit.lambda_1',1104 'myokit.drug_concentration',1105 'myokit.lambda_0']1106 problem.set_log_prior(log_priors, param_names)1107 def test_set_log_prior_bad_input(self):1108 problem = copy.deepcopy(self.pd_problem)1109 # No data has been set1110 with self.assertRaisesRegex(ValueError, 'The data has not'):1111 problem.set_log_prior('some prior')1112 # Wrong log-prior type1113 problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})1114 log_priors = ['Wrong', 'type']1115 with self.assertRaisesRegex(ValueError, 'All marginal log-priors'):1116 problem.set_log_prior(log_priors)1117 # Number of log priors does not match number of parameters1118 log_priors = [1119 pints.GaussianLogPrior(0, 1), pints.HalfCauchyLogPrior(0, 1)]1120 with self.assertRaisesRegex(ValueError, 'One marginal log-prior'):1121 problem.set_log_prior(log_priors)1122 # Dimensionality of joint log-pior does not match number of parameters1123 prior = pints.ComposedLogPrior(1124 pints.GaussianLogPrior(0, 1), pints.GaussianLogPrior(0, 1))1125 log_priors = [1126 prior,1127 pints.UniformLogPrior(0, 1),1128 pints.UniformLogPrior(0, 1),1129 pints.UniformLogPrior(0, 1),1130 pints.UniformLogPrior(0, 1),1131 
pints.UniformLogPrior(0, 1),1132 pints.UniformLogPrior(0, 1)]1133 with self.assertRaisesRegex(ValueError, 'The joint log-prior'):1134 problem.set_log_prior(log_priors)1135 # Specified parameter names do not match the model parameters1136 params = ['wrong', 'params']1137 log_priors = [pints.HalfCauchyLogPrior(0, 1)] * 71138 with self.assertRaisesRegex(ValueError, 'The specified parameter'):1139 problem.set_log_prior(log_priors, params)1140 def test_set_population_model(self):1141 # Test case I: PD model1142 problem = copy.deepcopy(self.pd_problem)1143 problem.set_data(self.data, {'myokit.tumour_volume': 'Tumour volume'})1144 pop_models = [1145 chi.PooledModel(),1146 chi.PooledModel(),1147 chi.HeterogeneousModel(),1148 chi.PooledModel(),1149 chi.PooledModel(),1150 chi.PooledModel(),1151 chi.LogNormalModel()]1152 # Test case I.1: Don't specify order1153 problem.set_population_model(pop_models)1154 self.assertEqual(problem.get_n_parameters(), 13)1155 param_names = problem.get_parameter_names()1156 self.assertEqual(len(param_names), 13)1157 self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')1158 self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')1159 self.assertEqual(param_names[2], 'ID 0: myokit.kappa')1160 self.assertEqual(param_names[3], 'ID 1: myokit.kappa')1161 self.assertEqual(param_names[4], 'ID 2: myokit.kappa')1162 self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')1163 self.assertEqual(param_names[6], 'Pooled myokit.lambda_1')1164 self.assertEqual(param_names[7], 'Pooled Sigma base')1165 self.assertEqual(param_names[8], 'ID 0: Sigma rel.')1166 self.assertEqual(param_names[9], 'ID 1: Sigma rel.')1167 self.assertEqual(param_names[10], 'ID 2: Sigma rel.')1168 self.assertEqual(param_names[11], 'Mean log Sigma rel.')1169 self.assertEqual(param_names[12], 'Std. 
log Sigma rel.')1170 # Test case I.2: Specify order1171 parameter_names = [1172 'Sigma base',1173 'myokit.drug_concentration',1174 'myokit.lambda_1',1175 'myokit.kappa',1176 'myokit.tumour_volume',1177 'Sigma rel.',1178 'myokit.lambda_0']1179 problem.set_population_model(pop_models, parameter_names)1180 self.assertEqual(problem.get_n_parameters(), 13)1181 param_names = problem.get_parameter_names()1182 self.assertEqual(len(param_names), 13)1183 self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')1184 self.assertEqual(param_names[1], 'Pooled myokit.drug_concentration')1185 self.assertEqual(param_names[2], 'Pooled myokit.kappa')1186 self.assertEqual(param_names[3], 'ID 0: myokit.lambda_0')1187 self.assertEqual(param_names[4], 'ID 1: myokit.lambda_0')1188 self.assertEqual(param_names[5], 'ID 2: myokit.lambda_0')1189 self.assertEqual(param_names[6], 'Mean log myokit.lambda_0')1190 self.assertEqual(param_names[7], 'Std. log myokit.lambda_0')1191 self.assertEqual(param_names[8], 'ID 0: myokit.lambda_1')1192 self.assertEqual(param_names[9], 'ID 1: myokit.lambda_1')1193 self.assertEqual(param_names[10], 'ID 2: myokit.lambda_1')1194 self.assertEqual(param_names[11], 'Pooled Sigma base')1195 self.assertEqual(param_names[12], 'Pooled Sigma rel.')1196 # Test case I.3: With covariates1197 cov_pop_model = chi.CovariatePopulationModel(1198 chi.GaussianModel(),1199 chi.LogNormalLinearCovariateModel(n_covariates=1)1200 )1201 cov_pop_model.set_covariate_names(['Age'], True)1202 pop_models = [1203 chi.PooledModel(),1204 chi.PooledModel(),1205 chi.HeterogeneousModel(),1206 chi.PooledModel(),1207 cov_pop_model,1208 chi.PooledModel(),1209 chi.LogNormalModel()]1210 problem.set_population_model(pop_models)1211 self.assertEqual(problem.get_n_parameters(), 18)1212 param_names = problem.get_parameter_names()1213 self.assertEqual(len(param_names), 18)1214 self.assertEqual(param_names[0], 'Pooled myokit.tumour_volume')1215 self.assertEqual(param_names[1], 'Pooled 
myokit.drug_concentration')1216 self.assertEqual(param_names[2], 'ID 0: myokit.kappa')1217 self.assertEqual(param_names[3], 'ID 1: myokit.kappa')1218 self.assertEqual(param_names[4], 'ID 2: myokit.kappa')1219 self.assertEqual(param_names[5], 'Pooled myokit.lambda_0')1220 self.assertEqual(param_names[6], 'ID 0: myokit.lambda_1 Eta')1221 self.assertEqual(param_names[7], 'ID 1: myokit.lambda_1 Eta')1222 self.assertEqual(param_names[8], 'ID 2: myokit.lambda_1 Eta')1223 self.assertEqual(param_names[9], 'Base mean log myokit.lambda_1')1224 self.assertEqual(param_names[10], 'Std. log myokit.lambda_1')1225 self.assertEqual(param_names[11], 'Shift Age myokit.lambda_1')1226 self.assertEqual(param_names[12], 'Pooled Sigma base')1227 self.assertEqual(param_names[13], 'ID 0: Sigma rel.')1228 self.assertEqual(param_names[14], 'ID 1: Sigma rel.')1229 self.assertEqual(param_names[15], 'ID 2: Sigma rel.')1230 self.assertEqual(param_names[16], 'Mean log Sigma rel.')1231 self.assertEqual(param_names[17], 'Std. 
log Sigma rel.')1232 # Test case II: PKPD model1233 problem = copy.deepcopy(self.pkpd_problem)1234 problem.set_data(1235 self.data,1236 output_observable_dict={1237 'central.drug_concentration': 'IL 6',1238 'myokit.tumour_volume': 'Tumour volume'})1239 pop_models = [1240 chi.LogNormalModel(),1241 chi.LogNormalModel(),1242 chi.LogNormalModel(),1243 chi.PooledModel(),1244 chi.PooledModel(),1245 chi.HeterogeneousModel(),1246 chi.PooledModel(),1247 chi.PooledModel(),1248 chi.LogNormalModel(),1249 chi.PooledModel(),1250 chi.LogNormalModel()]1251 # Test case I.1: Don't specify order1252 problem.set_population_model(pop_models)1253 self.assertEqual(problem.get_n_parameters(), 33)1254 param_names = problem.get_parameter_names()1255 self.assertEqual(len(param_names), 33)1256 self.assertEqual(param_names[0], 'ID 0: central.drug_amount')1257 self.assertEqual(param_names[1], 'ID 1: central.drug_amount')1258 self.assertEqual(param_names[2], 'ID 2: central.drug_amount')1259 self.assertEqual(param_names[3], 'Mean log central.drug_amount')1260 self.assertEqual(param_names[4], 'Std. log central.drug_amount')1261 self.assertEqual(param_names[5], 'ID 0: myokit.tumour_volume')1262 self.assertEqual(param_names[6], 'ID 1: myokit.tumour_volume')1263 self.assertEqual(param_names[7], 'ID 2: myokit.tumour_volume')1264 self.assertEqual(param_names[8], 'Mean log myokit.tumour_volume')1265 self.assertEqual(param_names[9], 'Std. log myokit.tumour_volume')1266 self.assertEqual(param_names[10], 'ID 0: central.size')1267 self.assertEqual(param_names[11], 'ID 1: central.size')1268 self.assertEqual(param_names[12], 'ID 2: central.size')1269 self.assertEqual(param_names[13], 'Mean log central.size')1270 self.assertEqual(param_names[14], 'Std. 
log central.size')1271 self.assertEqual(param_names[15], 'Pooled myokit.critical_volume')1272 self.assertEqual(param_names[16], 'Pooled myokit.elimination_rate')1273 self.assertEqual(param_names[17], 'ID 0: myokit.kappa')1274 self.assertEqual(param_names[18], 'ID 1: myokit.kappa')1275 self.assertEqual(param_names[19], 'ID 2: myokit.kappa')1276 self.assertEqual(param_names[20], 'Pooled myokit.lambda')1277 self.assertEqual(1278 param_names[21], 'Pooled central.drug_concentration Sigma base')1279 self.assertEqual(1280 param_names[22], 'ID 0: central.drug_concentration Sigma rel.')1281 self.assertEqual(1282 param_names[23], 'ID 1: central.drug_concentration Sigma rel.')1283 self.assertEqual(1284 param_names[24], 'ID 2: central.drug_concentration Sigma rel.')1285 self.assertEqual(1286 param_names[25], 'Mean log central.drug_concentration Sigma rel.')1287 self.assertEqual(1288 param_names[26], 'Std. log central.drug_concentration Sigma rel.')1289 self.assertEqual(1290 param_names[27], 'Pooled myokit.tumour_volume Sigma base')1291 self.assertEqual(1292 param_names[28], 'ID 0: myokit.tumour_volume Sigma rel.')1293 self.assertEqual(1294 param_names[29], 'ID 1: myokit.tumour_volume Sigma rel.')1295 self.assertEqual(1296 param_names[30], 'ID 2: myokit.tumour_volume Sigma rel.')1297 self.assertEqual(1298 param_names[31], 'Mean log myokit.tumour_volume Sigma rel.')1299 self.assertEqual(1300 param_names[32], 'Std. 
log myokit.tumour_volume Sigma rel.')1301 def test_set_population_model_bad_input(self):1302 # Population models have the wrong type1303 pop_models = ['bad', 'type']1304 with self.assertRaisesRegex(TypeError, 'The population models'):1305 self.pd_problem.set_population_model(pop_models)1306 # Number of population models is not correct1307 pop_models = [chi.PooledModel()]1308 with self.assertRaisesRegex(ValueError, 'The number of population'):1309 self.pd_problem.set_population_model(pop_models)1310 # Specified parameter names do not coincide with model1311 pop_models = [chi.PooledModel()] * 71312 parameter_names = ['wrong names'] * 71313 with self.assertRaisesRegex(ValueError, 'The parameter names'):1314 self.pd_problem.set_population_model(pop_models, parameter_names)1315 # User is warned that data is reset as a result of unclear covariate1316 # mapping1317 self.pd_problem.set_data(1318 self.data,1319 output_observable_dict={1320 'central.drug_concentration': 'IL 6',1321 'myokit.tumour_volume': 'Tumour volume'})1322 cov_pop_model = chi.CovariatePopulationModel(1323 chi.GaussianModel(),1324 chi.LogNormalLinearCovariateModel(n_covariates=1)1325 )1326 pop_models = [cov_pop_model] * 71327 with self.assertWarns(UserWarning):1328 self.pd_problem.set_population_model(pop_models)1329class TestInverseProblem(unittest.TestCase):1330 """1331 Tests the chi.InverseProblem class.1332 """1333 @classmethod1334 def setUpClass(cls):1335 # Create test data1336 cls.times = [1, 2, 3, 4, 5]1337 cls.values = [1, 2, 3, 4, 5]1338 # Set up inverse problem1339 path = ModelLibrary().tumour_growth_inhibition_model_koch()1340 cls.model = chi.PharmacodynamicModel(path)1341 cls.problem = chi.InverseProblem(cls.model, cls.times, cls.values)1342 def test_bad_model_input(self):1343 model = 'bad model'1344 with self.assertRaisesRegex(ValueError, 'Model has to be an instance'):1345 chi.InverseProblem(model, self.times, self.values)1346 def test_bad_times_input(self):1347 times = [-1, 2, 3, 4, 5]1348 
with self.assertRaisesRegex(ValueError, 'Times cannot be negative.'):1349 chi.InverseProblem(self.model, times, self.values)1350 times = [5, 4, 3, 2, 1]1351 with self.assertRaisesRegex(ValueError, 'Times must be increasing.'):1352 chi.InverseProblem(self.model, times, self.values)1353 def test_bad_values_input(self):1354 values = [1, 2, 3, 4, 5, 6, 7]1355 with self.assertRaisesRegex(ValueError, 'Values array must have'):1356 chi.InverseProblem(self.model, self.times, values)1357 values = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]1358 with self.assertRaisesRegex(ValueError, 'Values array must have'):1359 chi.InverseProblem(self.model, self.times, values)1360 def test_evaluate(self):1361 parameters = [0.1, 1, 1, 1, 1]1362 output = self.problem.evaluate(parameters)1363 n_times = 51364 n_outputs = 11365 self.assertEqual(output.shape, (n_times, n_outputs))1366 def test_evaluateS1(self):1367 parameters = [0.1, 1, 1, 1, 1]1368 with self.assertRaises(NotImplementedError):1369 self.problem.evaluateS1(parameters)1370 def test_n_ouputs(self):1371 self.assertEqual(self.problem.n_outputs(), 1)1372 def test_n_parameters(self):1373 self.assertEqual(self.problem.n_parameters(), 5)1374 def test_n_times(self):1375 n_times = len(self.times)1376 self.assertEqual(self.problem.n_times(), n_times)1377 def test_times(self):1378 times = self.problem.times()1379 n_times = len(times)1380 self.assertEqual(n_times, 5)1381 self.assertEqual(times[0], self.times[0])1382 self.assertEqual(times[1], self.times[1])1383 self.assertEqual(times[2], self.times[2])1384 self.assertEqual(times[3], self.times[3])1385 self.assertEqual(times[4], self.times[4])1386 def test_values(self):1387 values = self.problem.values()1388 n_times = 51389 n_outputs = 11390 self.assertEqual(values.shape, (n_times, n_outputs))1391 self.assertEqual(values[0], self.values[0])1392 self.assertEqual(values[1], self.values[1])1393 self.assertEqual(values[2], self.values[2])1394 self.assertEqual(values[3], self.values[3])1395 
self.assertEqual(values[4], self.values[4])1396if __name__ == '__main__':...
track_model_train.py
Source:track_model_train.py
1from __future__ import absolute_import, division, print_function2import numpy as np3import caffe4from caffe import layers as L5from caffe import params as P6channel_mean = np.array([123.68, 116.779, 103.939], dtype=np.float32)7###############################################################################8# Helper Methods9###############################################################################10def conv_relu(bottom, nout, ks=3, stride=1, pad=1, param_names=('conv_w', 'conv_b'), bias_term=True, fix_param=False, finetune=False):11 if fix_param:12 mult = [dict(name=param_names[0], lr_mult=0, decay_mult=0), dict(name=param_names[1], lr_mult=0, decay_mult=0)]13 conv = L.Convolution(bottom, kernel_size=ks, stride=stride,14 num_output=nout, pad=pad, param=mult)15 else:16 if finetune:17 mult = [dict(name=param_names[0], lr_mult=0.1, decay_mult=1), dict(name=param_names[1], lr_mult=0.2, decay_mult=0)]18 conv = L.Convolution(bottom, kernel_size=ks, stride=stride,19 num_output=nout, pad=pad, param=mult)20 else:21 mult = [dict(name=param_names[0], lr_mult=1, decay_mult=1), dict(name=param_names[1], lr_mult=2, decay_mult=0)]22 filler = dict(type='xavier')23 conv = L.Convolution(bottom, kernel_size=ks, stride=stride,24 num_output=nout, pad=pad, bias_term=bias_term,25 param=mult, weight_filler=filler)26 return conv, L.ReLU(conv, in_place=True)27def conv(bottom, nout, ks=3, stride=1, pad=1, param_names=('conv_w', 'conv_b'), bias_term=True, fix_param=False, finetune=False):28 if fix_param:29 mult = [dict(name=param_names[0], lr_mult=0, decay_mult=0), dict(name=param_names[1], lr_mult=0, decay_mult=0)]30 conv = L.Convolution(bottom, kernel_size=ks, stride=stride,31 num_output=nout, pad=pad, param=mult)32 else:33 if finetune:34 mult = [dict(name=param_names[0], lr_mult=0.1, decay_mult=1), dict(name=param_names[1], lr_mult=0.2, decay_mult=0)]35 conv = L.Convolution(bottom, kernel_size=ks, stride=stride,36 num_output=nout, pad=pad, param=mult)37 else:38 mult = 
[dict(name=param_names[0], lr_mult=1, decay_mult=1), dict(name=param_names[1], lr_mult=2, decay_mult=0)]39 filler = dict(type='xavier')40 conv = L.Convolution(bottom, kernel_size=ks, stride=stride,41 num_output=nout, pad=pad, bias_term=bias_term,42 param=mult, weight_filler=filler)43 return conv44def max_pool(bottom, ks=2, stride=2):45 return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)46################################################################################47# Model Generation48###############################################################################49def generate_scores(split, config):50 n = caffe.NetSpec()51 dataset = config.dataset52 batch_size = config.N53 mode_str = str(dict(dataset=dataset, split=split, batch_size=batch_size))54 n.image1, n.image2, n.label, n.sample_weights, n.feat_crop = L.Python(module=config.data_provider,55 layer=config.data_provider_layer,56 param_str=mode_str,57 ntop=5)58 ################################59 # the base net (VGG-16) branch 160 n.conv1_1, n.relu1_1 = conv_relu(n.image1, 64,61 param_names=('conv1_1_w', 'conv1_1_b'),62 fix_param=True,63 finetune=False)64 n.conv1_2, n.relu1_2 = conv_relu(n.relu1_1, 64,65 param_names=('conv1_2_w', 'conv1_2_b'),66 fix_param=True,67 finetune=False)68 n.pool1 = max_pool(n.relu1_2)69 n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128,70 param_names=('conv2_1_w', 'conv2_1_b'),71 fix_param=True,72 finetune=False)73 n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128,74 param_names=('conv2_2_w', 'conv2_2_b'),75 fix_param=True,76 finetune=False)77 n.pool2 = max_pool(n.relu2_2)78 n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256,79 param_names=('conv3_1_w', 'conv3_1_b'),80 fix_param=config.fix_vgg,81 finetune=config.finetune)82 n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256,83 param_names=('conv3_2_w', 'conv3_2_b'),84 fix_param=config.fix_vgg,85 finetune=config.finetune)86 n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256,87 param_names=('conv3_3_w', 'conv3_3_b'),88 
fix_param=config.fix_vgg,89 finetune=config.finetune)90 n.pool3 = max_pool(n.relu3_3)91 # spatial L2 norm92 n.pool3_lrn = L.LRN(n.pool3, local_size=513, alpha=513, beta=0.5, k=1e-16)93 n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512,94 param_names=('conv4_1_w', 'conv4_1_b'),95 fix_param=config.fix_vgg,96 finetune=config.finetune)97 n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512,98 param_names=('conv4_2_w', 'conv4_2_b'),99 fix_param=config.fix_vgg,100 finetune=config.finetune)101 n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512,102 param_names=('conv4_3_w', 'conv4_3_b'),103 fix_param=config.fix_vgg,104 finetune=config.finetune)105 # spatial L2 norm106 n.relu4_3_lrn = L.LRN(n.relu4_3, local_size=1025, alpha=1025, beta=0.5, k=1e-16)107 #n.pool4 = max_pool(n.relu4_3)108 #n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512,109 # param_names=('conv5_1_w', 'conv5_1_b'),110 # fix_param=config.fix_vgg,111 # finetune=config.finetune)112 #n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512,113 # param_names=('conv5_2_w', 'conv5_2_b'),114 # fix_param=config.fix_vgg,115 # finetune=config.finetune)116 #n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512,117 # param_names=('conv5_3_w', 'conv5_3_b'),118 # fix_param=config.fix_vgg,119 # finetune=config.finetune)120 # upsampling feature map121 #n.relu5_3_upsampling = L.Deconvolution(n.relu5_3,122 # convolution_param=dict(num_output=512,123 # group=512,124 # kernel_size=4,125 # stride=2,126 # pad=1,127 # bias_term=False,128 # weight_filler=dict(type='bilinear')),129 # param=[dict(lr_mult=0, decay_mult=0)])130 # spatial L2 norm131 #n.relu5_3_lrn = L.LRN(n.relu5_3_upsampling, local_size=1025, alpha=1025, beta=0.5, k=1e-16)132 # concat all skip features133 #n.feat_all1 = n.relu4_3_lrn134 n.feat_all1 = L.Concat(n.pool3_lrn, n.relu4_3_lrn, concat_param=dict(axis=1))135 #n.feat_all1 = L.Concat(n.pool3_lrn, n.relu4_3_lrn, n.relu5_3_lrn, concat_param=dict(axis=1))136 n.feat_all1_crop = L.Crop(n.feat_all1, n.feat_crop, crop_param=dict(axis=2, 
offset=[config.query_featmap_H//3, config.query_featmap_W//3]))137 138 ################################139 # the base net (VGG-16) branch 2140 n.conv1_1_p, n.relu1_1_p = conv_relu(n.image2, 64,141 param_names=('conv1_1_w', 'conv1_1_b'),142 fix_param=True,143 finetune=False)144 n.conv1_2_p, n.relu1_2_p = conv_relu(n.relu1_1_p, 64,145 param_names=('conv1_2_w', 'conv1_2_b'),146 fix_param=True,147 finetune=False)148 n.pool1_p = max_pool(n.relu1_2_p)149 n.conv2_1_p, n.relu2_1_p = conv_relu(n.pool1_p, 128,150 param_names=('conv2_1_w', 'conv2_1_b'),151 fix_param=True,152 finetune=False)153 n.conv2_2_p, n.relu2_2_p = conv_relu(n.relu2_1_p, 128,154 param_names=('conv2_2_w', 'conv2_2_b'),155 fix_param=True,156 finetune=False)157 n.pool2_p = max_pool(n.relu2_2_p)158 n.conv3_1_p, n.relu3_1_p = conv_relu(n.pool2_p, 256,159 param_names=('conv3_1_w', 'conv3_1_b'),160 fix_param=config.fix_vgg,161 finetune=config.finetune)162 n.conv3_2_p, n.relu3_2_p = conv_relu(n.relu3_1_p, 256,163 param_names=('conv3_2_w', 'conv3_2_b'),164 fix_param=config.fix_vgg,165 finetune=config.finetune)166 n.conv3_3_p, n.relu3_3_p = conv_relu(n.relu3_2_p, 256,167 param_names=('conv3_3_w', 'conv3_3_b'),168 fix_param=config.fix_vgg,169 finetune=config.finetune)170 n.pool3_p = max_pool(n.relu3_3_p)171 # spatial L2 norm172 n.pool3_lrn_p = L.LRN(n.pool3_p, local_size=513, alpha=513, beta=0.5, k=1e-16)173 n.conv4_1_p, n.relu4_1_p = conv_relu(n.pool3_p, 512,174 param_names=('conv4_1_w', 'conv4_1_b'),175 fix_param=config.fix_vgg,176 finetune=config.finetune)177 n.conv4_2_p, n.relu4_2_p = conv_relu(n.relu4_1_p, 512,178 param_names=('conv4_2_w', 'conv4_2_b'),179 fix_param=config.fix_vgg,180 finetune=config.finetune)181 n.conv4_3_p, n.relu4_3_p = conv_relu(n.relu4_2_p, 512,182 param_names=('conv4_3_w', 'conv4_3_b'),183 fix_param=config.fix_vgg,184 finetune=config.finetune)185 # spatial L2 norm186 n.relu4_3_lrn_p = L.LRN(n.relu4_3_p, local_size=1025, alpha=1025, beta=0.5, k=1e-16)187 #n.pool4_p = 
max_pool(n.relu4_3_p)188 #n.conv5_1_p, n.relu5_1_p = conv_relu(n.pool4_p, 512,189 # param_names=('conv5_1_w', 'conv5_1_b'),190 # fix_param=config.fix_vgg,191 # finetune=config.finetune)192 #n.conv5_2_p, n.relu5_2_p = conv_relu(n.relu5_1_p, 512,193 # param_names=('conv5_2_w', 'conv5_2_b'),194 # fix_param=config.fix_vgg,195 # finetune=config.finetune)196 #n.conv5_3_p, n.relu5_3_p = conv_relu(n.relu5_2_p, 512,197 # param_names=('conv5_3_w', 'conv5_3_b'),198 # fix_param=config.fix_vgg,199 # finetune=config.finetune)200 # upsampling feature map201 #n.relu5_3_upsampling_p = L.Deconvolution(n.relu5_3_p,202 # convolution_param=dict(num_output=512,203 # group=512,204 # kernel_size=4,205 # stride=2,206 # pad=1,207 # bias_term=False,208 # weight_filler=dict(type='bilinear')),209 # param=[dict(lr_mult=0, decay_mult=0)])210 # spatial L2 norm211 #n.relu5_3_lrn_p = L.LRN(n.relu5_3_upsampling_p, local_size=1025, alpha=1025, beta=0.5, k=1e-16)212 # concat all skip features213 #n.feat_all2 = n.relu4_3_lrn_p214 n.feat_all2 = L.Concat(n.pool3_lrn_p, n.relu4_3_lrn_p, concat_param=dict(axis=1))215 #n.feat_all2 = L.Concat(n.pool3_lrn_p, n.relu4_3_lrn_p, n.relu5_3_lrn_p, concat_param=dict(axis=1))216 # Dyn conv layer217 n.fcn_scores = L.DynamicConvolution(n.feat_all2, n.feat_all1_crop,218 convolution_param=dict(num_output=1,219 kernel_size=11,220 stride=1,221 pad=5,222 bias_term=False))223 return n.to_proto()224def generate_model(split, config):225 n = caffe.NetSpec()226 dataset = config.dataset227 batch_size = config.N228 mode_str = str(dict(dataset=dataset, split=split, batch_size=batch_size))229 n.image1, n.image2, n.label, n.sample_weights, n.feat_crop = L.Python(module=config.data_provider,230 layer=config.data_provider_layer,231 param_str=mode_str,232 ntop=5)233 ################################234 # the base net (VGG-16) branch 1235 n.conv1_1, n.relu1_1 = conv_relu(n.image1, 64,236 param_names=('conv1_1_w', 'conv1_1_b'),237 fix_param=True,238 finetune=False)239 n.conv1_2, n.relu1_2 
= conv_relu(n.relu1_1, 64,240 param_names=('conv1_2_w', 'conv1_2_b'),241 fix_param=True,242 finetune=False)243 n.pool1 = max_pool(n.relu1_2)244 n.conv2_1, n.relu2_1 = conv_relu(n.pool1, 128,245 param_names=('conv2_1_w', 'conv2_1_b'),246 fix_param=True,247 finetune=False)248 n.conv2_2, n.relu2_2 = conv_relu(n.relu2_1, 128,249 param_names=('conv2_2_w', 'conv2_2_b'),250 fix_param=True,251 finetune=False)252 n.pool2 = max_pool(n.relu2_2)253 n.conv3_1, n.relu3_1 = conv_relu(n.pool2, 256,254 param_names=('conv3_1_w', 'conv3_1_b'),255 fix_param=config.fix_vgg,256 finetune=config.finetune)257 n.conv3_2, n.relu3_2 = conv_relu(n.relu3_1, 256,258 param_names=('conv3_2_w', 'conv3_2_b'),259 fix_param=config.fix_vgg,260 finetune=config.finetune)261 n.conv3_3, n.relu3_3 = conv_relu(n.relu3_2, 256,262 param_names=('conv3_3_w', 'conv3_3_b'),263 fix_param=config.fix_vgg,264 finetune=config.finetune)265 n.pool3 = max_pool(n.relu3_3)266 # spatial L2 norm267 n.pool3_lrn = L.LRN(n.pool3, local_size=513, alpha=513, beta=0.5, k=1e-16)268 n.conv4_1, n.relu4_1 = conv_relu(n.pool3, 512,269 param_names=('conv4_1_w', 'conv4_1_b'),270 fix_param=config.fix_vgg,271 finetune=config.finetune)272 n.conv4_2, n.relu4_2 = conv_relu(n.relu4_1, 512,273 param_names=('conv4_2_w', 'conv4_2_b'),274 fix_param=config.fix_vgg,275 finetune=config.finetune)276 n.conv4_3, n.relu4_3 = conv_relu(n.relu4_2, 512,277 param_names=('conv4_3_w', 'conv4_3_b'),278 fix_param=config.fix_vgg,279 finetune=config.finetune)280 # spatial L2 norm281 n.relu4_3_lrn = L.LRN(n.relu4_3, local_size=1025, alpha=1025, beta=0.5, k=1e-16)282 #n.pool4 = max_pool(n.relu4_3)283 #n.conv5_1, n.relu5_1 = conv_relu(n.pool4, 512,284 # param_names=('conv5_1_w', 'conv5_1_b'),285 # fix_param=config.fix_vgg,286 # finetune=config.finetune)287 #n.conv5_2, n.relu5_2 = conv_relu(n.relu5_1, 512,288 # param_names=('conv5_2_w', 'conv5_2_b'),289 # fix_param=config.fix_vgg,290 # finetune=config.finetune)291 #n.conv5_3, n.relu5_3 = conv_relu(n.relu5_2, 512,292 # 
param_names=('conv5_3_w', 'conv5_3_b'),293 # fix_param=config.fix_vgg,294 # finetune=config.finetune)295 # upsampling feature map296 #n.relu5_3_upsampling = L.Deconvolution(n.relu5_3,297 # convolution_param=dict(num_output=512,298 # group=512,299 # kernel_size=4,300 # stride=2,301 # pad=1,302 # bias_term=False,303 # weight_filler=dict(type='bilinear')),304 # param=[dict(lr_mult=0, decay_mult=0)])305 # spatial L2 norm306 #n.relu5_3_lrn = L.LRN(n.relu5_3_upsampling, local_size=1025, alpha=1025, beta=0.5, k=1e-16)307 # concat all skip features308 #n.feat_all1 = n.relu4_3_lrn309 n.feat_all1 = L.Concat(n.pool3_lrn, n.relu4_3_lrn, concat_param=dict(axis=1))310 #n.feat_all1 = L.Concat(n.pool3_lrn, n.relu4_3_lrn, n.relu5_3_lrn, concat_param=dict(axis=1))311 n.feat_all1_crop = L.Crop(n.feat_all1, n.feat_crop, crop_param=dict(axis=2, offset=[config.query_featmap_H//3, config.query_featmap_W//3]))312 ################################313 # the base net (VGG-16) branch 2314 n.conv1_1_p, n.relu1_1_p = conv_relu(n.image2, 64,315 param_names=('conv1_1_w', 'conv1_1_b'),316 fix_param=True,317 finetune=False)318 n.conv1_2_p, n.relu1_2_p = conv_relu(n.relu1_1_p, 64,319 param_names=('conv1_2_w', 'conv1_2_b'),320 fix_param=True,321 finetune=False)322 n.pool1_p = max_pool(n.relu1_2_p)323 n.conv2_1_p, n.relu2_1_p = conv_relu(n.pool1_p, 128,324 param_names=('conv2_1_w', 'conv2_1_b'),325 fix_param=True,326 finetune=False)327 n.conv2_2_p, n.relu2_2_p = conv_relu(n.relu2_1_p, 128,328 param_names=('conv2_2_w', 'conv2_2_b'),329 fix_param=True,330 finetune=False)331 n.pool2_p = max_pool(n.relu2_2_p)332 n.conv3_1_p, n.relu3_1_p = conv_relu(n.pool2_p, 256,333 param_names=('conv3_1_w', 'conv3_1_b'),334 fix_param=config.fix_vgg,335 finetune=config.finetune)336 n.conv3_2_p, n.relu3_2_p = conv_relu(n.relu3_1_p, 256,337 param_names=('conv3_2_w', 'conv3_2_b'),338 fix_param=config.fix_vgg,339 finetune=config.finetune)340 n.conv3_3_p, n.relu3_3_p = conv_relu(n.relu3_2_p, 256,341 param_names=('conv3_3_w', 
'conv3_3_b'),342 fix_param=config.fix_vgg,343 finetune=config.finetune)344 n.pool3_p = max_pool(n.relu3_3_p)345 # spatial L2 norm346 n.pool3_lrn_p = L.LRN(n.pool3_p, local_size=513, alpha=513, beta=0.5, k=1e-16)347 n.conv4_1_p, n.relu4_1_p = conv_relu(n.pool3_p, 512,348 param_names=('conv4_1_w', 'conv4_1_b'),349 fix_param=config.fix_vgg,350 finetune=config.finetune)351 n.conv4_2_p, n.relu4_2_p = conv_relu(n.relu4_1_p, 512,352 param_names=('conv4_2_w', 'conv4_2_b'),353 fix_param=config.fix_vgg,354 finetune=config.finetune)355 n.conv4_3_p, n.relu4_3_p = conv_relu(n.relu4_2_p, 512,356 param_names=('conv4_3_w', 'conv4_3_b'),357 fix_param=config.fix_vgg,358 finetune=config.finetune)359 # spatial L2 norm360 n.relu4_3_lrn_p = L.LRN(n.relu4_3_p, local_size=1025, alpha=1025, beta=0.5, k=1e-16)361 #n.pool4_p = max_pool(n.relu4_3_p)362 #n.conv5_1_p, n.relu5_1_p = conv_relu(n.pool4_p, 512,363 # param_names=('conv5_1_w', 'conv5_1_b'),364 # fix_param=config.fix_vgg,365 # finetune=config.finetune)366 #n.conv5_2_p, n.relu5_2_p = conv_relu(n.relu5_1_p, 512,367 # param_names=('conv5_2_w', 'conv5_2_b'),368 # fix_param=config.fix_vgg,369 # finetune=config.finetune)370 #n.conv5_3_p, n.relu5_3_p = conv_relu(n.relu5_2_p, 512,371 # param_names=('conv5_3_w', 'conv5_3_b'),372 # fix_param=config.fix_vgg,373 # finetune=config.finetune)374 # upsampling feature map375 #n.relu5_3_upsampling_p = L.Deconvolution(n.relu5_3_p,376 # convolution_param=dict(num_output=512,377 # group=512,378 # kernel_size=4,379 # stride=2,380 # pad=1,381 # bias_term=False,382 # weight_filler=dict(type='bilinear')),383 # param=[dict(lr_mult=0, decay_mult=0)])384 # spatial L2 norm385 #n.relu5_3_lrn_p = L.LRN(n.relu5_3_upsampling_p, local_size=1025, alpha=1025, beta=0.5, k=1e-16)386 # concat all skip features387 #n.feat_all2 = n.relu4_3_lrn_p388 n.feat_all2 = L.Concat(n.pool3_lrn_p, n.relu4_3_lrn_p, concat_param=dict(axis=1))389 #n.feat_all2 = L.Concat(n.pool3_lrn_p, n.relu4_3_lrn_p, n.relu5_3_lrn_p, 
concat_param=dict(axis=1))390 # Dyn conv layer391 n.fcn_scores = L.DynamicConvolution(n.feat_all2, n.feat_all1_crop,392 convolution_param=dict(num_output=1,393 kernel_size=11,394 stride=1,395 pad=5,396 bias_term=False))397 398 # scale scores with zero mean 0.01196 -> 0.02677399 n.fcn_scaled_scores = L.Power(n.fcn_scores, power_param=dict(scale=0.01196,400 shift=-1.0,401 power=1))402 # Loss Layer403 n.loss = L.WeightedSigmoidCrossEntropyLoss(n.fcn_scaled_scores, n.label, n.sample_weights)...
translations.py
Source:translations.py
1"""2 Flowblade Movie Editor is a nonlinear video editor.3 Copyright 2012 Janne Liljeblad.4 This file is part of Flowblade Movie Editor <http://code.google.com/p/flowblade>.5 Flowblade Movie Editor is free software: you can redistribute it and/or modify6 it under the terms of the GNU General Public License as published by7 the Free Software Foundation, either version 3 of the License, or8 (at your option) any later version.9 Flowblade Movie Editor is distributed in the hope that it will be useful,10 but WITHOUT ANY WARRANTY; without even the implied warranty of11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the12 GNU General Public License for more details.13 You should have received a copy of the GNU General Public License14 along with Flowblade Movie Editor. If not, see <http://www.gnu.org/licenses/>.15"""16import gettext17import locale18import os19import respaths20APP_NAME = "Flowblade"21lang = None22filter_groups = {}23filter_names = {}24param_names = {}25combo_options = {}26def init_languages():27 langs = []28 lc, encoding = locale.getdefaultlocale()29 if (lc):30 langs = [lc]31 print "Locale:", lc32 language = os.environ.get('LANGUAGE', None)33 if (language):34 langs += language.split(":")35 gettext.bindtextdomain(APP_NAME, respaths.LOCALE_PATH)36 gettext.textdomain(APP_NAME)37 # Get the language to use38 global lang39 #lang = gettext.translation(APP_NAME, respaths.LOCALE_PATH, languages=["fi"], fallback=True) # Testing, comment out for production40 lang = gettext.translation(APP_NAME, respaths.LOCALE_PATH, languages=langs, fallback=True)41 lang.install(APP_NAME) # makes _() a build-in available in all modules without imports42def get_filter_name(f_name):43 try:44 return filter_names[f_name]45 except KeyError:46 return f_name47def get_filter_group_name(group_name):48 try:49 return filter_groups[group_name]50 except:51 return group_name52def get_param_name(name):53 try:54 return param_names[name]55 except KeyError:56 return name57def 
get_combo_option(c_opt):58 try:59 return combo_options[c_opt]60 except KeyError:61 return c_opt62 63def load_filters_translations():64 # filter group names65 global filter_groups66 filter_groups["Color"] = _("Color")67 filter_groups["Color Effect"] = _("Color Effect")68 filter_groups["Audio"] = _("Audio")69 filter_groups["Audio Filter"] = _("Audio Filter")70 filter_groups["Blur"] = _("Blur")71 filter_groups["Distort"] = _("Distort")72 filter_groups["Alpha"] = _("Alpha")73 filter_groups["Movement"] = _("Movement")74 filter_groups["Transform"] = _("Transform")75 filter_groups["Edge"] = _("Edge")76 filter_groups["Fix"] = _("Fix")77 filter_groups["Artistic"] = _("Artistic")78 # filter names79 global filter_names80 filter_names["Alpha Gradient"] = _("Alpha Gradient")81 filter_names["Crop"] = _("Crop")82 filter_names["Alpha Shape"]= _("Alpha Shape")83 84 filter_names["Volume"]= _("Volume")85 filter_names["Pan"]= _("Pan")86 filter_names["Pan Keyframed"]= _("Pan Keyframed")87 filter_names["Mono to Stereo"]= _("Mono to Stereo")88 filter_names["Swap Channels"]= _("Swap Channels")89 filter_names["Pitchshifter"]= _("Pitchshifter")90 filter_names["Distort - Barry's Satan"]= _("Distort - Barry's Satan")91 filter_names["Frequency Shift - Bode/Moog"]= _("Frequency Shift - Bode/Moog")92 filter_names["Equalize - DJ 3-band"]= _("Equalize - DJ 3-band")93 filter_names["Flanger - DJ"]= _("Flanger - DJ")94 filter_names["Declipper"]= _("Declipper")95 filter_names["Delayorama"]= _("Delayorama")96 filter_names["Distort - Diode Processor"]= _("Distort - Diode Processor")97 filter_names["Distort - Foldover"]= _("Distort - Foldover")98 filter_names["Highpass - Butterworth"]= _("Highpass - Butterworth")99 filter_names["Lowpass - Butterworth"]= _("Lowpass - Butterworth")100 filter_names["GSM Simulator"]= _("GSM Simulator")101 filter_names["Reverb - GVerb"]= _("Reverb - GVerb")102 filter_names["Noise Gate"]= _("Noise Gate")103 filter_names["Bandpass"]= _("Bandpass")104 filter_names["Pitchscaler - 
High Quality"]= _("Pitchscaler - High Quality")105 filter_names["Equalize - Multiband"]= _("Equalize - Multiband")106 filter_names["Reverb - Plate"]= _("Reverb - Plate")107 filter_names["Distort - Pointer cast"]= _("Distort - Pointer cast")108 filter_names["Rate Shifter"]= _("Rate Shifter")109 filter_names["Signal Shifter"]= _("Signal Shifter")110 filter_names["Distort - Sinus Wavewrap"]= _("Distort - Sinus Wavewrap")111 filter_names["Vinyl Effect"]= _("Vinyl Effect")112 filter_names["Chorus - Multivoice"]= _("Chorus - Multivoice")113 filter_names["Charcoal"]= _("Charcoal")114 filter_names["Glow"]= _("Glow")115 filter_names["Old Film"]= _("Old Film")116 filter_names["Scanlines"]= _("Scanlines")117 filter_names["Cartoon"]= _("Cartoon")118 119 filter_names["Pixelize"]= _("Pixelize")120 filter_names["Blur"]= _("Blur")121 filter_names["Grain"]= _("Grain")122 123 filter_names["Grayscale"]= _("Grayscale")124 filter_names["Contrast"]= _("Contrast")125 filter_names["Saturation"]= _("Saturation")126 filter_names["Invert"]= _("Invert")127 filter_names["Hue"]= _("Hue")128 filter_names["Brightness"]= _("Brightness")129 filter_names["Sepia"]= _("Sepia")130 filter_names["Tint"]= _("Tint")131 filter_names["White Balance"]= _("White Balance")132 filter_names["Levels"]= _("Levels")133 filter_names["Color Clustering"]= _("Color Clustering")134 filter_names["Chroma Hold"]= _("Chroma Hold")135 filter_names["Three Layer"]= _("Three Layer")136 filter_names["Threshold0r"]= _("Threshold0r")137 filter_names["Technicolor"]= _("Technicolor")138 filter_names["Primaries"]= _("Primaries")139 filter_names["Color Distance"]= _("Color Distance")140 filter_names["Threshold"]= _("Threshold")141 filter_names["Waves"]= _("Waves")142 filter_names["Lens Correction"]= _("Lens Correction")143 filter_names["Flip"]= _("Flip")144 filter_names["Mirror"]= _("Mirror")145 filter_names["V Sync"]= _("V Sync")146 filter_names["Edge Glow"]= _("Edge Glow")147 filter_names["Sobel"]= _("Sobel")148 
filter_names["Denoise"]= _("Denoise")149 filter_names["Sharpness"]= _("Sharpness")150 filter_names["Letterbox"]= _("Letterbox")151 filter_names["Baltan"]= _("Baltan")152 filter_names["Vertigo"]= _("Vertigo")153 filter_names["Nervous"]= _("Nervous")154 filter_names["Freeze"]= _("Freeze")155 filter_names["Rotate"]= _("Rotate")156 filter_names["Shear"]= _("Shear")157 filter_names["Translate"]= _("Translate")158 # 0.8 added159 filter_names["Color Select"]= _("Color Select")160 filter_names["Alpha Modify"]= _("Alpha Modify")161 filter_names["Spill Supress"]= _("Spill Supress")162 filter_names["RGB Noise"]= _("RGB Noise")163 filter_names["Box Blur"]= _("Box Blur")164 filter_names["IRR Blur"]= _("IRR Blur")165 filter_names["Color Halftone"]= _("Color Halftone")166 filter_names["Dither"]= _("Dither")167 filter_names["Vignette"]= _("Vignette")168 filter_names["Emboss"]= _("Emboss")169 filter_names["3 Point Balance"]= _("3 Point Balance")170 filter_names["Colorize"]= _("Colorize")171 filter_names["Brightness Keyframed"]= _("Brightness Keyframed")172 filter_names["RGB Adjustment"]= _("RGB Adjustment")173 filter_names["Color Tap"]= _("Color Tap")174 filter_names["Posterize"]= _("Posterize")175 filter_names["Soft Glow"]= _("Soft Glow")176 filter_names["Newspaper"]= _("Newspaper")177 # 0.16 added178 filter_names["Luma Key"] = _("Luma Key")179 filter_names["Chroma Key"] = _("Chroma Key")180 filter_names["Affine"] = _("Affine")181 filter_names["Color Adjustment"] = _("Color Adjustment")182 filter_names["Color Grading"] = _("Color Grading")183 filter_names["Curves"] = _("Curves")184 filter_names["Lift Gain Gamma"] = _("Lift Gain Gamma")185 filter_names["Image Grid"] = _("Image Grid")186 187 # 0.18188 filter_names["Color Lift Gain Gamma"] = _("Color Lift Gain Gamma")189 190 # param names191 global param_names192 # param names for filters193 param_names["Position"] = _("Position")194 param_names["Grad width"] = _("Grad width")195 param_names["Tilt"] = _("Tilt")196 param_names["Min"] 
= _("Min")197 param_names["Max"] = _("Max")198 param_names["Left"] = _("Left")199 param_names["Right"] = _("Right")200 param_names["Top"] = _("Top")201 param_names["Bottom"] = _("Bottom")202 param_names["Shape"] = _("Shape")203 param_names["Pos X"] = _("Pos X")204 param_names["Pos Y"] = _("Pos Y")205 param_names["Size X"] = _("Size X")206 param_names["Size Y"] = _("Size Y")207 param_names["Tilt"] = _("Tilt")208 param_names["Trans. Width"] = _("Trans. Width")209 param_names["Volume"] = _("Volume")210 param_names["Left/Right"] = _("Left/Right")211 param_names["Left/Right"] = _("Left/Right")212 param_names["Dry/Wet"] = _("Dry/Wet")213 param_names["Pitch Shift"] = _("Pitch Shift")214 param_names["Buffer Size"] = _("Buffer Size")215 param_names["Dry/Wet"] = _("Dry/Wet")216 param_names["Decay Time(samples)"] = _("Decay Time(samples)")217 param_names["Knee Point(dB)"] = _("Knee Point(dB)")218 param_names["Dry/Wet"] = _("Dry/Wet")219 param_names["Frequency shift"] = _("Frequency shift")220 param_names["Dry/Wet"] = _("Dry/Wet")221 param_names["Low Gain(dB)"] = _("Low Gain(dB)")222 param_names["Mid Gain(dB)"] = _("Mid Gain(dB)")223 param_names["High Gain(dB)"] = _("High Gain(dB)")224 param_names["Dry/Wet"] = _("Dry/Wet")225 param_names["Oscillation period(s)"] = _("Oscillation period(s)")226 param_names["Oscillation depth(ms)"] = _("Oscillation depth(ms)")227 param_names["Feedback%"] = _("Feedback%")228 param_names["Dry/Wet"] = _("Dry/Wet")229 param_names["Dry/Wet"] = _("Dry/Wet")230 param_names["Random seed"] = _("Random seed")231 param_names["Input Gain(dB)"] = _("Input Gain(dB)")232 param_names["Feedback(%)"] = _("Feedback(%)")233 param_names["Number of taps"] = _("Number of taps")234 param_names["First Delay(s)"] = _("First Delay(s)")235 param_names["Delay Range(s)"] = _("Delay Range(s)")236 param_names["Delay Change"] = _("Delay Change")237 param_names["Delay Random(%)"] = _("Delay Random(%)")238 param_names["Amplitude Change"] = _("Amplitude Change")239 
param_names["Amplitude Random(%)"] = _("Amplitude Random(%)")240 param_names["Dry/Wet"] = _("Dry/Wet")241 param_names["Amount"] = _("Amount")242 param_names["Dry/Wet"] = _("Dry/Wet")243 param_names["Drive"] = _("Drive")244 param_names["Skew"] = _("Skew")245 param_names["Dry/Wet"] = _("Dry/Wet")246 param_names["Cutoff Frequency(Hz)"] = _("Cutoff Frequency(Hz)")247 param_names["Resonance"] = _("Resonance")248 param_names["Dry/Wet"] = _("Dry/Wet")249 param_names["Cutoff Frequency(Hz)"] = _("Cutoff Frequency(Hz)")250 param_names["Resonance"] = _("Resonance")251 param_names["Dry/Wet"] = _("Dry/Wet")252 param_names["Passes"] = _("Passes")253 param_names["Error Rate"] = _("Error Rate")254 param_names["Dry/Wet"] = _("Dry/Wet")255 param_names["Roomsize"] = _("Roomsize")256 param_names["Reverb time(s)"] = _("Reverb time(s)")257 param_names["Damping"] = _("Damping")258 param_names["Input bandwith"] = _("Input bandwith")259 param_names["Dry signal level(dB)"] = _("Dry signal level(dB)")260 param_names["Early reflection level(dB)"] = _("Early reflection level(dB)")261 param_names["Tail level(dB)"] = _("Tail level(dB)")262 param_names["Dry/Wet"] = _("Dry/Wet")263 param_names["LF keyfilter(Hz)"] = _("LF keyfilter(Hz)")264 param_names["HF keyfilter(Hz)"] = _("HF keyfilter(Hz)")265 param_names["Threshold(dB)"] = _("Threshold(dB)")266 param_names["Attack(ms)"] = _("Attack(ms)")267 param_names["Hold(ms)"] = _("Hold(ms)")268 param_names["Decay(ms)"] = _("Decay(ms)")269 param_names["Range(dB)"] = _("Range(dB)")270 param_names["Dry/Wet"] = _("Dry/Wet")271 param_names["Center Frequency(Hz)"] = _("Center Frequency(Hz)")272 param_names["Bandwidth(Hz)"] = _("Bandwidth(Hz)")273 param_names["Stages"] = _("Stages")274 param_names["Dry/Wet"] = _("Dry/Wet")275 param_names["Pitch-coefficient"] = _("Pitch-coefficient")276 param_names["Dry/Wet"] = _("Dry/Wet")277 param_names["50Hz gain"] = _("50Hz gain")278 param_names["100Hz gain"] = _("100Hz gain")279 param_names["156Hz gain"] = _("156Hz 
gain")280 param_names["220Hz gain"] = _("220Hz gain")281 param_names["311Hz gain"] = _("311Hz gain")282 param_names["440Hz gain"] = _("440Hz gain")283 param_names["622Hz gain"] = _("622Hz gain")284 param_names["880Hz gain"] = _("880Hz gain")285 param_names["1250Hz gain"] = _("1250Hz gain")286 param_names["1750Hz gain"] = _("1750Hz gain")287 param_names["2500Hz gain"] = _("2500Hz gain")288 param_names["3500Hz gain"] = _("3500Hz gain")289 param_names["5000Hz gain"] = _("5000Hz gain")290 param_names["100000Hz gain"] = _("100000Hz gain")291 param_names["200000Hz gain"] = _("200000Hz gain")292 param_names["Dry/Wet"] = _("Dry/Wet")293 param_names["Reverb time"] = _("Reverb time")294 param_names["Damping"] = _("Damping")295 param_names["Dry/Wet mix"] = _("Dry/Wet mix")296 param_names["Dry/Wet"] = _("Dry/Wet")297 param_names["Effect cutoff(Hz)"] = _("Effect cutoff(Hz)")298 param_names["Dry/Wet mix"] = _("Dry/Wet mix")299 param_names["Dry/Wet"] = _("Dry/Wet")300 param_names["Rate"] = _("Rate")301 param_names["Dry/Wet"] = _("Dry/Wet")302 param_names["Sift"] = _("Sift")303 param_names["Dry/Wet"] = _("Dry/Wet")304 param_names["Amount"] = _("Amount")305 param_names["Dry/Wet"] = _("Dry/Wet")306 param_names["Year"] = _("Year")307 param_names["RPM"] = _("RPM")308 param_names["Surface warping"] = _("Surface warping")309 param_names["Cracle"] = _("Cracle")310 param_names["Wear"] = _("Wear")311 param_names["Dry/Wet"] = _("Dry/Wet")312 param_names["Number of voices"] = _("Number of voices")313 param_names["Delay base(ms)"] = _("Delay base(ms)")314 param_names["Voice separation(ms)"] = _("Voice separation(ms)")315 param_names["Detune(%)"] = _("Detune(%)")316 param_names["Oscillation frequency(Hz)"] = _("Oscillation frequency(Hz)")317 param_names["Output attenuation(dB)"] = _("Output attenuation(dB)")318 param_names["Dry/Wet"] = _("Dry/Wet")319 param_names["X Scatter"] = _("X Scatter")320 param_names["Y Scatter"] = _("Y Scatter")321 param_names["Scale"] = _("Scale")322 
param_names["Mix"] = _("Mix")323 param_names["Invert"] = _("Invert")324 param_names["Blur"] = _("Blur")325 param_names["Delta"] = _("Delta")326 param_names["Duration"] = _("Duration")327 param_names["Bright. up"] = _("Bright. up")328 param_names["Bright. down"] = _("Bright. down")329 param_names["Bright. dur."] = _("Bright. dur.")330 param_names["Develop up"] = _("Develop up")331 param_names["Develop down"] = _("Develop down")332 param_names["Develop dur."] = _("Develop dur.")333 param_names["Triplevel"] = _("Triplevel")334 param_names["Difference Space"] = _("Difference Space")335 param_names["Block width"] = _("Block width")336 param_names["Block height"] = _("Block height")337 param_names["Size"] = _("Size")338 param_names["Noise"] = _("Noise")339 param_names["Contrast"] = _("Contrast")340 param_names["Brightness"] = _("Brightness")341 param_names["Contrast"] = _("Contrast")342 param_names["Saturation"] = _("Saturation")343 param_names["Hue"] = _("Hue")344 param_names["Brightness"] = _("Brightness")345 param_names["Brightness"] = _("Brightness")346 param_names["U"] = _("U")347 param_names["V"] = _("V")348 param_names["Black"] = _("Black")349 param_names["White"] = _("White")350 param_names["Amount"] = _("Amount")351 param_names["Neutral Color"] = _("Neutral Color")352 param_names["Input"] = _("Input")353 param_names["Input"] = _("Input")354 param_names["Gamma"] = _("Gamma")355 param_names["Black"] = _("Black")356 param_names["White"] = _("White")357 param_names["Num"] = _("Num")358 param_names["Dist. weight"] = _("Dist. 
weight")359 param_names["Color"] = _("Color")360 param_names["Variance"] = _("Variance")361 param_names["Threshold"] = _("Threshold")362 param_names["Red Saturation"] = _("Red Saturation")363 param_names["Yellow Saturation"] = _("Yellow Saturation")364 param_names["Factor"] = _("Factor")365 param_names["Source color"] = _("Source color")366 param_names["Threshold"] = _("Threshold")367 param_names["Amplitude"] = _("Amplitude")368 param_names["Frequency"] = _("Frequency")369 param_names["Rotate"] = _("Rotate")370 param_names["Tilt"] = _("Tilt")371 param_names["Center Correct"] = _("Center Correct")372 param_names["Edges Correct"] = _("Edges Correct")373 param_names["Flip"] = _("Flip")374 param_names["Axis"] = _("Axis")375 param_names["Invert"] = _("Invert")376 param_names["Position"] = _("Position")377 param_names["Edge Lightning"] = _("Edge Lightning")378 param_names["Edge Brightness"] = _("Edge Brightness")379 param_names["Non-Edge Brightness"] = _("Non-Edge Brightness")380 param_names["Spatial"] = _("Spatial")381 param_names["Temporal"] = _("Temporal")382 param_names["Amount"] = _("Amount")383 param_names["Size"] = _("Size")384 param_names["Border width"] = _("Border width")385 param_names["Phase Incr."] = _("Phase Incr.")386 param_names["Zoom"] = _("Zoom")387 param_names["Freeze Frame"] = _("Freeze Frame")388 param_names["Freeze After"] = _("Freeze After")389 param_names["Freeze Before"] = _("Freeze Before")390 param_names["Angle"] = _("Angle")391 param_names["transition.geometry"] = _("transition.geometry")392 param_names["Shear X"] = _("Shear X")393 param_names["Shear Y"] = _("Shear Y")394 param_names["transition.geometry"] = _("transition.geometry")395 param_names["transition.geometry"] = _("transition.geometry")396 param_names["Left"] = _("Left")397 param_names["Right"] = _("Right")398 param_names["Top"] = _("Top")399 param_names["Bottom"] = _("Bottom")400 param_names["Invert"] = _("Invert")401 param_names["Blur"] = _("Blur")402 param_names["Opacity"] = 
_("Opacity")403 param_names["Opacity"] = _("Opacity")404 param_names["Rotate X"] = _("Rotate X")405 param_names["Rotate Y"] = _("Rotate Y")406 param_names["Rotate Z"] = _("Rotate Z")407 # added 0.8408 param_names["Edge Mode"] = _("Edge Mode")409 param_names["Sel. Space"] = _("Sel. Space")410 param_names["Operation"] = _("Operation")411 param_names["Hard"] = _("Hard")412 param_names["R/A/Hue"] = _("R/A/Hue")413 param_names["G/B/Chromae"] = _("G/B/Chroma")414 param_names["B/I/I"] = _("B/I/I")415 param_names["Supress"] = _("Supress")416 param_names["Horizontal"] = _("Horizontal")417 param_names["Vertical"] = _("Vertical")418 param_names["Type"] = _("Type")419 param_names["Edge"] = _("Edge")420 param_names["Dot Radius"] = _("Dot Radius")421 param_names["Cyan Angle"] = _("Cyan Angle")422 param_names["Magenta Angle"] = _("Magenta Angle")423 param_names["Yellow Angle"] = _("Yellow Angle")424 param_names["Levels"] = _("Levels")425 param_names["Matrix Type"] = _("Matrix Type")426 param_names["Aspect"] = _("Aspect")427 param_names["Center Size"] = _("Center Size")428 param_names["Azimuth"] = _("Azimuth")429 param_names["Lightness"] = _("Lightness")430 param_names["Bump Height"] = _("Bump Height")431 param_names["Gray"] = _("Gray")432 param_names["Split Preview"] = _("Split Preview")433 param_names["Source on Left"] = _("Source on Left")434 param_names["Lightness"] = _("Lightness")435 param_names["Input black level"] = _("Input black level")436 param_names["Input white level"] = _("Input white level")437 param_names["Black output"] = _("Black output")438 param_names["White output"] = _("White output")439 param_names["Red"] = _("Red")440 param_names["Green"] = _("Green")441 param_names["Blue"] = _("Blue")442 param_names["Action"] = _("Action")443 param_names["Keep Luma"] = _("Keep Luma")444 param_names["Luma Formula"] = _("Luma Formula")445 param_names["Effect"] = _("Effect")446 param_names["Sharpness"] = _("Sharpness")447 param_names["Blend Type"] = _("Blend Type")448 # added 
0.16449 param_names["Key Color"] = _("Key Color")450 param_names["Pre-Level"] = _("Pre-Level")451 param_names["Post-Level"] = _("Post-Level")452 param_names["Slope"] = _("Slope")453 param_names["Luma Band"] = _("Luma Band")454 param_names["Lift"] = _("Lift")455 param_names["Gain"] = _("Gain")456 param_names["Input White Level"] = _("Input White Level")457 param_names["Input Black Level"] = _("Input Black Level")458 param_names["Black Output"] = _("Black Output")459 param_names["White Output"] = _("White Output")460 param_names["Rows"] = _("Rows")461 param_names["Columns"] = _("Columns")462 param_names["Color Temperature"] = _("Color Temperature")463 # param names for compositors464 param_names["Opacity"] = _("Opacity")465 param_names["Shear X"] = _("Shear X")466 param_names["Shear Y"] = _("Shear Y")467 param_names["Distort"] = _("Distort")468 param_names["Opacity"] = _("Opacity")469 param_names["Wipe Type"] = _("Wipe Type")470 param_names["Invert"] = _("Invert")471 param_names["Softness"] = _("Softness")472 param_names["Wipe Amount"] = _("Wipe Amount")473 param_names["Wipe Type"] = _("Wipe Type")474 param_names["Invert"] = _("Invert")475 param_names["Softness"] = _("Softness")476 # Combo options477 global combo_options478 combo_options["Shave"] = _("Shave")479 combo_options["Rectangle"] = _("Rectangle")480 combo_options["Ellipse"] = _("Ellipse")481 combo_options["Triangle"] = _("Triangle")482 combo_options["Diamond"] = _("Diamond")483 combo_options["Shave"] = _("Shave")484 combo_options["Shrink Hard"] = _("Shrink Hard")485 combo_options["Shrink Soft"] = _("Shrink Soft")486 combo_options["Grow Hard"] = _("Grow Hard")487 combo_options["Grow Soft"] = _("Grow Soft")488 combo_options["RGB"] = _("RGB")489 combo_options["ABI"] = _("ABI")490 combo_options["HCI"] = _("HCI")491 combo_options["Hard"] = _("Hard")492 combo_options["Fat"] = _("Fat")493 combo_options["Normal"] = _("Normal")494 combo_options["Skinny"] = _("Skinny")495 combo_options["Ellipsoid"] = _("Ellipsoid")496 
combo_options["Diamond"] = _("Diamond")497 combo_options["Overwrite"] = _("Overwrite")498 combo_options["Max"] = _("Max")499 combo_options["Min"] = _("Min")500 combo_options["Add"] = _("Add")501 combo_options["Subtract"] = _("Subtract")502 combo_options["Green"] = _("Green")503 combo_options["Blue"] = _("Blue")504 combo_options["Sharper"] = _("Sharper")505 combo_options["Fuzzier"] = _("Fuzzier")506 combo_options["Luma"] = _("Luma")507 combo_options["Red"] = _("Red")508 combo_options["Green"] = _("Green")509 combo_options["Blue"] = _("Blue")510 combo_options["Add Constant"] = _("Add Constant")511 combo_options["Change Gamma"] = _("Change Gamma")512 combo_options["Multiply"] = _("Multiply")513 combo_options["XPro"] = _("XPro")514 combo_options["OldPhoto"] = _("OldPhoto")515 combo_options["Sepia"] = _("Sepia")516 combo_options["Heat"] = _("Heat")517 combo_options["XRay"] = _("XRay")518 combo_options["RedGreen"] = _("RedGreen")519 combo_options["YellowBlue"] = _("YellowBlue")520 combo_options["Esses"] = _("Esses")521 combo_options["Horizontal"] = _("Horizontal")522 combo_options["Vertical"] = _("Vertical")523 combo_options["Shadows"] = _("Shadows")524 combo_options["Midtones"] = _("Midtones")525 combo_options["Highlights"] = _("Highlights")...
interpolate_core_collapse_timescale.py
Source:interpolate_core_collapse_timescale.py
import numpy as np
from sidmpy.core_collapse_timescale import fraction_collapsed_halos, fraction_collapsed_halos_pool
from scipy.interpolate import RegularGridInterpolator
from scipy.interpolate import interp1d
import pickle
import itertools
from multiprocess.pool import Pool


class InterpolatedCollapseTimescale(object):
    """Interpolator for a quantity tabulated on a regular grid of model parameters.

    Wraps scipy's ``RegularGridInterpolator`` built with ``bounds_error=False,
    fill_value=None``, so queries outside the grid are extrapolated rather than
    rejected. Use :meth:`fromParamArray` to tabulate ``fraction_collapsed_halos_pool``
    on a grid and build the interpolator in one step.
    """

    def __init__(self, points, values, param_names, param_arrays, shape=None):
        """
        :param points: tuple of 1D sample arrays, one per grid dimension
        :param values: tabulated values on the grid; either already shaped like the grid,
            or a flat sequence in C order (last axis varying fastest) when ``shape`` is given
        :param param_names: parameter names, index-aligned with ``param_arrays``
        :param param_arrays: sequence of 1D sample arrays, one per parameter
        :param shape: optional grid shape used to reshape a flat ``values`` sequence
            (this is what :meth:`fromParamArray` passes)
        """
        self.param_names = param_names
        self.param_ranges = []
        self.param_ranges_dict = {}
        for i, param in enumerate(param_arrays):
            # [first, last] sample of each axis (equals [min, max] only for ascending arrays)
            ran = [param[0], param[-1]]
            self.param_ranges.append(ran)
            self.param_ranges_dict[param_names[i]] = ran
        values = np.asarray(values)
        if shape is not None:
            # pool.map returns a flat list; RegularGridInterpolator needs one axis per parameter
            values = values.reshape(shape)
        self._interp_function = RegularGridInterpolator(points, values,
                                                        bounds_error=False, fill_value=None)

    @classmethod
    def fromParamArray(cls, m1, m2, cross_section_model, param_names, param_arrays, params_fixed=None,
                       kwargs_fraction=None, nproc=8):
        """Tabulate the collapsed-halo fraction on a parameter grid and build an interpolator.

        The grid is the Cartesian product of ``param_arrays``. The LAST array is always the
        redshift grid. With 7 arrays, the sixth is the grid of ``timescale_factor`` and the
        first five are cross-section parameters. With 2-6 arrays, all but the last are
        cross-section parameters, and ``kwargs_fraction['timescale_factor']`` is used at
        every grid point.

        :param m1: passed unchanged as the first entry of every argument tuple given to
            ``fraction_collapsed_halos_pool`` (presumably the lower halo mass -- confirm)
        :param m2: passed unchanged as the second entry of every argument tuple
            (presumably the upper halo mass -- confirm)
        :param cross_section_model: callable (e.g. a cross-section class) instantiated as
            ``cross_section_model(**kw)`` at each grid point; ``kw`` maps the cross-section
            parameter names to the grid values and is then updated with ``params_fixed``
        :param param_names: parameter names, index-aligned with ``param_arrays``
        :param param_arrays: 2 to 7 1D sample arrays (see above for their meaning)
        :param params_fixed: extra keyword arguments for ``cross_section_model``; they
            override grid values of the same name. Defaults to none.
        :param kwargs_fraction: must contain ``'timescale_factor'`` unless 7 arrays are
            given. It is only read, never modified.
        :param nproc: number of worker processes used to evaluate the grid
        :return: an instance of ``cls`` interpolating over the grid
        :raises Exception: if ``len(param_arrays)`` is not between 2 and 7
        :raises KeyError: if ``'timescale_factor'`` is required but missing from ``kwargs_fraction``
        """
        # None defaults avoid sharing (and mutating) a single dict across calls
        params_fixed = {} if params_fixed is None else params_fixed
        kwargs_fraction = {} if kwargs_fraction is None else kwargs_fraction

        print('param_names: ', param_names)
        print('n params: ', len(param_names))
        print('n sample arrays: ', len(param_arrays))

        n_dim = len(param_arrays)
        if not 2 <= n_dim <= 7:
            raise Exception('only 2, 3, 4, 5, 6 and 7D interpolations implemented')

        points = tuple(param_arrays)
        shape = tuple(len(param) for param in param_arrays)
        n_total = int(np.prod(shape))
        print('n total: ', n_total)

        # redshift is always last; in the 7D case the timescale factor is second-to-last,
        # leaving five cross-section parameters
        n_cross = 5 if n_dim == 7 else n_dim - 1
        args_list = []
        # itertools.product varies the last axis fastest, matching the C-order reshape in __init__
        for grid_point in itertools.product(*param_arrays):
            kw = {param_names[i]: grid_point[i] for i in range(n_cross)}
            kw.update(params_fixed)
            if n_dim == 7:
                timescale_factor = grid_point[5]
            else:
                timescale_factor = kwargs_fraction['timescale_factor']
            redshift = grid_point[-1]
            cross_model = cross_section_model(**kw)
            args_list.append((m1, m2, cross_model, redshift, timescale_factor))

        # the context manager guarantees the worker processes are cleaned up
        with Pool(nproc) as pool:
            values = pool.map(fraction_collapsed_halos_pool, args_list)

        return cls(points, values, param_names, param_arrays, shape)

    def __call__(self, *args):
        """Evaluate the interpolator.

        :param args: one coordinate per grid dimension, in the same order as ``param_arrays``
        :return: the interpolated value(s) with length-1 axes removed by ``np.squeeze``
        """
        return np.squeeze(self._interp_function(tuple(args)))


def interpolate_collapse_fraction(fname, cross_section_class, param_names, param_arrays, params_fixed, m1,
                                  kwargs_collapse_fraction, nproc=8):
    """Build an InterpolatedCollapseTimescale and pickle it to disk.

    Calls ``InterpolatedCollapseTimescale.fromParamArray`` with ``m1`` and ``1.05 * m1`` as the
    first two arguments. Writes the result to ``'interpolated_collapse_fraction_' + fname``,
    a path resolved relative to the current working directory.

    :param fname: suffix of the output file name
    :param cross_section_class: forwarded as ``cross_section_model``
    :param param_names: forwarded unchanged
    :param param_arrays: forwarded unchanged
    :param params_fixed: forwarded unchanged
    :param m1: passed as ``m1``; ``1.05 * m1`` is passed as ``m2``
    :param kwargs_collapse_fraction: forwarded as ``kwargs_fraction``
    :param nproc: number of worker processes (default 8, same as ``fromParamArray``)
    :return: the interpolator that was pickled
    """
    interp_timescale = InterpolatedCollapseTimescale.fromParamArray(
        m1, m1 * 1.05, cross_section_class, param_names, param_arrays, params_fixed,
        kwargs_collapse_fraction, nproc=nproc)
    with open('interpolated_collapse_fraction_' + fname, 'wb') as f:
        pickle.dump(interp_timescale, f)
    return interp_timescale


# Example usage (kept for reference; not executed):
# from sidmpy.CrossSections.resonant_tchannel import ExpResonantTChannel
# # norm, v_ref, v_res, w_res, res_amplitude
# param_names = ['norm', 'v_ref', 'v_res', 'w_res', 'res_amplitude', 'timescale_factor', 'redshift']
# cross_model = ExpResonantTChannel
#
# output_folder = ''
# nproc = 50
# params_fixed = {}
# kwargs_collapse_fraction = {}
# z_array = [0.2, 0.45, 0.7, 0.95]
# tarray = [10/3, 15/3, 20/3]
# param_arrays = [np.linspace(1, 10.0, 9), np.linspace(1, 50.0, 20), np.linspace(1, 40, 20),
#                 np.linspace(1, 5.0, 5), np.linspace(1.0, 100, 40), tarray, z_array]
# n_total = 1
# for parr in param_arrays:
#     n_total *= len(parr)
# print('n_total: ', n_total); a=input('continue')
# fname = output_folder + 'logM68_expresonanttchannel'
# m1 = 10 ** 7
# interpolate_collapse_fraction(fname, cross_model, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction, nproc=nproc)
# fname = output_folder + 'logM89_expresonanttchannel'
# m1 = 10 ** 8.5
# interpolate_collapse_fraction(fname, cross_model, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction, nproc=nproc)
#
# fname = output_folder + 'logM910_expresonanttchannel'
# m1 = 10 ** 9.5
# interpolate_collapse_fraction(fname, cross_model, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction, nproc=nproc)
#
# NOTE(review): fromParamArray treats the LAST array as the redshift grid, so this 2D example
# would use the v_ref samples as redshifts -- confirm before running.
# from sidmpy.CrossSections.tchannel import TChannel
# param_names = ['norm', 'v_ref']
# n = 50
# params_fixed = {}
# kwargs_collapse_fraction = {'redshift': 0.5, 'timescale_factor': 20.0}
# param_arrays = [np.linspace(0.5, 60.0, n), np.linspace(1.0, 40, n)]
# fname = 'logM68_tchannel'
# m1 = 10 ** 7
# interpolate_collapse_fraction(fname, TChannel, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction)
#
# fname = 'logM89_tchannel'
# m1 = 5 * 10 ** 8
# interpolate_collapse_fraction(fname, TChannel, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction)
#
# fname = 'logM910_tchannel'
# m1 = 5 * 10 ** 9
# interpolate_collapse_fraction(fname, TChannel, param_names, param_arrays, params_fixed, m1, kwargs_collapse_fraction)
Learn to run automated tests from scratch with the LambdaTest Learning Hub. It covers everything from setting up the prerequisites and running your first automated test to following best practices and diving deeper into advanced test scenarios. The LambdaTest Learning Hubs compile step-by-step guides to help you become proficient with test automation frameworks such as Selenium, Cypress, and TestNG.
You can also refer to the video tutorials on the LambdaTest YouTube channel for step-by-step demonstrations from industry experts.
Get 100 minutes of automation testing FREE!