Best Python code snippet using autotest_python
ingest.py
Source:ingest.py
1"""2Ingest construct used to perform basic schema-driven validations3and to add some standard meta data values to be used downstream4"""56import argparse7import boto38import glob9import json10import numpy as np11import pandas as pd12import re1314from abc import ABCMeta, abstractmethod15from datetime import datetime16from pyspark import SparkConf17from pyspark.sql import DataFrame, SparkSession18from pyspark.sql import functions as F19from pyspark.sql import types as T2021from spark_process_common.job_config import JobConfig22from spark_process_common.transforms import normalize_columns, snek2324class BaseIngest(metaclass=ABCMeta):25 """26 Base class for ingest. This provides the controller pattern27 structure to be sure that the extract and save functions28 are properly called.2930 This class also provides the implementation for loading31 the json configuration.32 """3334 def __init__(self, spark_config: dict, json_config: dict=None):35 """36 Default constructor.3738 Parameters:3940 spark_config - contains the configuration settings in dictionary form41 to pass to the SparkSession builder.4243 example:44 {45 "spark.sql.hive.convertMetastoreOrc": "true",46 "spark.sql.files.ignoreMissingFiles": "true",47 "spark.sql.adaptive.enabled": "true",48 "spark.sql.hive.verifyPartitionPath": "false",49 "spark.sql.orc.filterPushdown": "true",50 "spark.sql.sources.partitionOverwriteMode": "DYNAMIC"51 }5253 json_config - contains the config_bucket and config_path parameters54 for development and testing usage that will be collected from the55 spark-submit arguments in cluster mode.5657 example:58 {59 "config_bucket": "/vagrant/mill_profitability_ingest/job-config/",60 "config_path": "extract_pbm_as400_urbcrb_invoices.json"61 }62 """63 self.__extract_dttm = None6465 if json_config:66 with open(json_config.get("config_bucket") + json_config.get("config_path")) as f:67 self.__config = json.loads(f.read())68 else:69 parser = argparse.ArgumentParser()70 parser.add_argument(71 'config_bucket',72 help='Used by boto to specify S3 bucket to json config files.',73 default='wrktdtransformationprocessproddtl001', 74 type=str)75 parser.add_argument(76 'config_path',77 help='Used by boto to specify S3 path to json config file.', 78 type=str) 79 80 args = parser.parse_args()81 config_bucket = args.config_bucket82 config_path = args.config_path8384 # Configuration85 job_config = JobConfig(config_bucket, config_path)86 self.__config = job_config.get()8788 # SparkSession Configuration89 spark_conf = SparkConf()90 for key, value in spark_config.items():91 spark_conf.set(key, value)92 self.__spark_conf = spark_conf9394 # Build the SparkSession here95 self.__spark = (96 SparkSession.builder.appName(self.__config.get('app_name'))97 .config(conf=self.__spark_conf)98 .enableHiveSupport()99 .getOrCreate()100 )101102 def __enter__(self):103 """104 Function called when entering a with block.105106 Not building the SparkSession her so that in dev107 we can skip the with statement which causes the closing108 of the metastore connection, requiring a restart of the109 jupyter kernel after each execution. 
110 """111 return self112113 def __exit__(self, type, value, traceback):114 """115 Function called when exiting a with block.116117 Ensure the spark session is stopped when exiting a with 118 statement119 """120 self.__spark.stop()121122 def get_spark(self):123 """124 Helper function to retrieve the generated SparkSession.125 """126 return self.__spark127128 def get_config(self):129 """130 Helper function to retrieve a copy of the JSON Configuration.131 """132 return self.__config.copy()133134 @staticmethod135 def __reorder_columns_for_partitioning(column_list: list, partitioned_by: list) -> list():136 """137 Helper function to reorder partition columns to be last (required by Spark / Hive).138 """139 _column_list = column_list.copy()140 _partitioned_by = partitioned_by.copy()141 non_partition_columns = [i for i in _column_list if i not in _partitioned_by]142 _column_list = non_partition_columns + _partitioned_by143144 return _column_list145146 @staticmethod147 def __validate_primary_key(data_frame: DataFrame, key_columns: list, data_source: str):148 """149 Function to validate that source data conforms to the proper uniquness, as defined150 by the list of keys in the job-config json.151 """152 if key_columns:153154 df = data_frame.groupBy(*key_columns).count().alias('count')155 if not df.where('count > 1').rdd.isEmpty():156 df = df.where('count > 1')157 strData = str(df.limit(5).collect())158 cntDups = df.count()159 errMessage = "Source violates configured primary key. "160 if cntDups == 1:161 errMessage = "{} There is {} duplicate row".format(errMessage, str(cntDups))162 else:163 errMessage = "{} There are {} duplicate rows".format(errMessage, str(cntDups))164 165 errMessage = "{} in {}. ".format(errMessage, data_source)166 errMessage = "{} Key Columns: {} ".format(errMessage, str(key_columns))167 errMessage = "{} Duplicate Data: {}".format(errMessage, strData)168 169 170 raise IngestPrimaryKeyViolationError(errMessage)171 172173 def __stage_source(self, hive_db: str, hive_table: str, partitioned_by: list, column_list: list, data_frame: DataFrame):174 """175 Function called to save the ingested Dataframe into Hive stage table -> 'SOURCE' partition176 """177 spark = self.get_spark()178 _column_list = column_list.copy()179 _partitioned_by = []180 if partitioned_by:181 _partitioned_by = partitioned_by.copy()182183 _partitioned_by.append('iptmeta_record_origin')184 _column_list.append('iptmeta_record_origin')185 _column_list = [snek(c) for c in _column_list]186187 data_frame = normalize_columns(data_frame)188 data_frame = data_frame.withColumn('iptmeta_record_origin', F.lit('SOURCE'))189 190 source_view_name = "{}_{}_src_temp".format(hive_db, hive_table)191 data_frame.createOrReplaceTempView(source_view_name)192 193 columns = ", ".join([i for i in _column_list])194 partitions = ", ".join(_partitioned_by)195196 dml = (197 "TRUNCATE TABLE stage.{}_{}"198 .format(hive_db, hive_table)199 )200 spark.sql(dml)201202 # drop all partitions203 dml = "ALTER TABLE stage.{}_{} ".format(hive_db, hive_table)204 dml = "{} DROP IF EXISTS PARTITION (iptmeta_record_origin='TARGET')".format(dml)205 spark.sql(dml)206207 dml = "ALTER TABLE stage.{}_{} ".format(hive_db, hive_table)208 dml = "{} DROP IF EXISTS PARTITION (iptmeta_record_origin='SOURCE')".format(dml)209 spark.sql(dml)210211 dml = "INSERT OVERWRITE TABLE stage.{}_{}".format(hive_db, hive_table)212 dml = "{} PARTITION({})".format(dml, partitions)213 dml = "{} SELECT {} FROM {}".format(dml, columns, source_view_name) 214 215 spark.sql(dml)216217 def 
__stage_target(self, hive_db: str, hive_table: str, key_columns: list, partitioned_by: list, column_list: list):218 """219 Function to stage target rows for keys not found in the source. Prepares stage for insert overwrite.220 """221 spark = self.get_spark()222 223 joins = ""224 joins = " AND ".join(["tgt.{} = src.{}".format(i,i) for i in key_columns])225 _partitioned_by = []226 filter_partitions = []227 if partitioned_by:228 _partitioned_by = partitioned_by.copy()229230 partition_columns = ", ".join(["tgt.{}".format(i) for i in _partitioned_by])231232 # get the list of partition values from the target table that have matching primary keys233 dml = "SELECT DISTINCT {} ".format(partition_columns)234 dml = "{} FROM {}.{} as tgt".format(dml, hive_db, hive_table)235 dml = "{} INNER JOIN stage.{}_{} as src ON {} ".format(dml, hive_db, hive_table, joins)236237 part_list_tgt = spark.sql(dml)238 239 # get list of partitions from staging table 240 dml = (241 "SELECT DISTINCT {} FROM stage.{}_{} as tgt"242 .format(partition_columns, hive_db, hive_table)243 ) 244 245 part_list_src = spark.sql(dml)246 247 # final list of partition columns that we will select from 248 part_list = (249 part_list_src250 .union(part_list_tgt)251 .distinct() 252 )253 part_columns = part_list.columns 254 255 filter_partitions = []256 for row in part_list.collect():257 row_filter = " AND ".join(["tgt.{} = '{}'".format(part_columns[x],y) for x,y in enumerate(row)])258 filter_partitions.append("({})".format(row_filter))259 260 filter_partitions = " OR ".join(filter_partitions) 261262 columns = ", ".join(["tgt.{}".format(i) for i in column_list])263 _partitioned_by.append('iptmeta_record_origin')264 partitions = ", ".join(_partitioned_by)265 266 dml = "INSERT OVERWRITE TABLE stage.{}_{}".format(hive_db, hive_table)267 dml = "{} PARTITION({})".format(dml, partitions)268 dml = "{} SELECT {}, 'TARGET' as iptmeta_record_origin".format(dml, columns)269 dml = "{} FROM {}.{} as tgt".format(dml, hive_db, hive_table)270 dml = "{} LEFT OUTER JOIN stage.{}_{} as src ON {} WHERE src.{} IS null".format(dml, hive_db, hive_table, joins, key_columns[0])271 if filter_partitions:272 dml = "{} AND ({})".format(dml, filter_partitions)273 274 spark.sql(dml)275276 def __write(self, hive_db: str, hive_table: str, key_columns: list, partitioned_by: list, 277 column_list: list, partition_overwrite: bool, force_truncate: bool):278 """279 Function to perform the final write to target from stage280 """281 spark = self.get_spark()282283 # Table should be completely overwritten if no primary key exists284 # Insert Overwrite may not overwrite all partitions285 if (partitioned_by and not key_columns and not partition_overwrite) or force_truncate:286 dml = (287 "TRUNCATE TABLE {}.{}"288 .format(hive_db, hive_table)289 )290 spark.sql(dml)291292 if hive_db == 'master_data_sources':293 column_list.append('iptmeta_record_origin')294295 columns = [snek(c).replace('__', '_') for c in column_list]296 columns = ", ".join([i for i in columns])297298 if partitioned_by: 299 partitions = ", ".join(partitioned_by)300 dml = (301 "INSERT OVERWRITE TABLE {}.{} PARTITION({}) SELECT {} FROM stage.{}_{}"302 .format(hive_db, hive_table, partitions, columns, hive_db, hive_table)303 )304 else:305 dml = (306 "INSERT OVERWRITE TABLE {}.{} SELECT {} FROM stage.{}_{}"307 .format(hive_db, hive_table, columns, hive_db, hive_table)308 ) 309310 spark.sql(dml)311312 def __archive(self, archive_path: str, data_frame: DataFrame=None):313 """314 Function to saves input files into a timestamped 
archive directory for troubleshooting or re-processing 315 """316 archive_dttm = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')317 archive_path = "{}/{}".format(archive_path, archive_dttm)318319 data_frame = normalize_columns(data_frame)320 data_frame.write.orc(path=archive_path, mode='append')321322 def pre_processing(self, data_frame: DataFrame) -> DataFrame:323 """324 Function called to perform any necessary pre processing.325 Meant to be overriden when necessary.326 """ 327 return data_frame328329 @staticmethod330 def add_meta_data_columns(data_frame: DataFrame) -> DataFrame:331 """332 Append extracted DataFrames with some common meta data fields333 334 iptmeta_extract_dttm: A single timestamp to help with troubleshooting335 iptmeta_corrupt_record: to be used / populated as needed during validation336 iptmeta_record_origin: An identifier to partition stage on, used to distinquish 337 between new rows inbound compared to target rows to be merged338 """ 339340 if 'iptmeta_corrupt_record' not in data_frame.columns:341 data_frame = data_frame.withColumn('iptmeta_corrupt_record', F.lit(None).cast(T.StringType()))342343 data_frame = data_frame.withColumn('iptmeta_extract_dttm', F.current_timestamp())344345 return data_frame346347 @abstractmethod348 def extract(self, config: dict) -> DataFrame:349 # Config requirements may be different among child class implementations 350 raise NotImplementedError351352 def execute(self):353 """354 Entry point for the Ingest process. This function controls the355 ordering of extract and save process.356357 If key_columns exist, then data will be merged. Target records will be updated358 if a match on primary key is found. Otherwise, new records will be inserted.359 360 If key_columns do not exist, the target data will be truncated and overwritten.361 362 If key_columns do not exist, partitioned_by columns exist, and non_key_partition_overwrite, then363 only the partitions that exist in the source will be overwritten. 
364 The entire target table will not be truncated in this case.365366 If force_truncate is true, the target data will be truncated and overwritten.367 """368369 config = self.get_config()370371 hive_db = config["hive_db"]372 hive_table = config["hive_table"]373374 key_columns = config.get("key_columns")375 partitioned_by = config.get("partitioned_by")376 archive_path = config['archive_path']377 force_truncate = config.get('force_truncate')378379 read_options = config.get("read_options")380 path = config.get("path") 381382 if read_options or path:383 data_source = "data source: {} {}".format(path, str(read_options))384 else:385 data_source = "sqoop.{}_{} table".format(hive_db, hive_table)386387 # non_key_partition_overwrite argument will prevent the target table from being truncated.388 non_key_partition_overwrite = config.get("non_key_partition_overwrite")389 390 if key_columns or (not partitioned_by):391 non_key_partition_overwrite = False392393 df = self.extract(config)394395 # Return with no error if the sheet is optional and any Excel file does not contain it396 sheet_optional = config.get('sheet_optional')397 if sheet_optional:398 if not df:399 return400401 if df.count() == 0:402 # Return with no error if empty data frame is an acceptable outcome403 success_if_no_records = config.get('success_if_no_records')404 if success_if_no_records:405 return406 # Otherwise throw error407 else:408 errMessage = "No records exist in {}.".format(data_source)409 raise IngestNoDataError(errMessage)410411 # Perform any pre-processing that may be necessary412 df = self.pre_processing(data_frame=df)413414 # Add meta data fields to source output415 df = self.add_meta_data_columns(df)416417 column_list = df.columns418 if partitioned_by:419 column_list = self.__reorder_columns_for_partitioning(420 column_list=column_list, 421 partitioned_by=partitioned_by422 )423424 df.repartition(*partitioned_by)425 else:426 df.coalesce(10)427428 # Insert Overwrite into stage.(hive_db)_(hive_table) from source -> SOURCE Partition429 self.__stage_source(430 hive_db=hive_db,431 hive_table=hive_table,432 partitioned_by=partitioned_by,433 column_list=column_list,434 data_frame=df435 )436437 if key_columns:438 self.__validate_primary_key(data_frame=df, key_columns=key_columns, data_source=data_source)439440 if key_columns and not force_truncate:441 # Insert Overwrite into stage.(hive_db)_(hive_table) from source -> Target Partition442 self.__stage_target(443 hive_db=hive_db,444 hive_table=hive_table,445 key_columns=key_columns,446 partitioned_by=partitioned_by,447 column_list=column_list448 )449450 self.__write(451 hive_db=hive_db,452 hive_table=hive_table,453 key_columns=key_columns,454 partitioned_by=partitioned_by,455 column_list=column_list,456 partition_overwrite=non_key_partition_overwrite, 457 force_truncate=force_truncate458 )459 460 # Write orc file to s3 to retain archive of ingest #461 self.__archive(462 archive_path=archive_path,463 data_frame=df464 )465466467class IngestCSV(BaseIngest):468 """469 Implementation of the BaseIngest class that provides for ingesting470 CSV files.471472 The schema utilized by this class follows the PySpark StructType schema.473 """474475 def extract(self, config: dict) -> DataFrame:476 """477 Entry point for extract logic.478479 Apply schema and add meta data.480 """481 path = config['source_path']482 read_options = config.get('read_options')483 schema = config.get('schema', dict())484485 _read_options = read_options.copy()486487 if schema:488 # StructType.fromJson method looks for 
'fields' key ..489 schema_struct = T.StructType.fromJson(schema) 490 schema_struct.add(T.StructField('iptmeta_corrupt_record', T.StringType(), True))491 _read_options['schema'] = schema_struct492 else:493 _read_options['inferSchema'] = 'true'494495 df = self.get_spark().read.csv(path, **_read_options)496 #df.cache()497498 return df 499500class IngestFixedLengthFile(BaseIngest):501 """502 Implementation of the BaseIngest class that provides for ingesting503 fixed length files.504505 sample schema json:506 "schema": {507 "fields": [508 {"name": "Record_Type", "type": "string", "nullable": true, "metadata": "", "start": 1, "length": 3},509 {"name": "Currency_Code_From", "type": "string", "nullable": true, "metadata": "", "start": 4, "length": 3},510 {"name": "Currency_Code_To", "type": "string", "nullable": true, "metadata": "", "start": 7, "length": 3},511 {"name": "Effective_Date", "type": "string", "nullable": true, "metadata": "", "start": 10, "length": 8},512 {"name": "Conversion_Rate_Multiplier", "type": "string", "nullable": true, "metadata": "", "start": 18, "length": 15},513 {"name": "Conversion_Rate_Divisor", "type": "string", "nullable": true, "metadata": "", "start": 33, "length": 15},514 {"name": "Extract_Date", "type": "string", "nullable": true, "metadata": "", "start": 48, "length": 8},515 {"name": "Extract_Time", "type": "string", "nullable": true, "metadata": "", "start": 56, "length": 6}516 ]517 }518 """519 @abstractmethod520 def filter_df(self, data_frame: DataFrame) -> DataFrame:521 raise NotImplementedError522523 @abstractmethod524 def process_columns(self, data_frame: DataFrame) -> DataFrame:525 raise NotImplementedError526527 def extract(self, config: dict) -> DataFrame:528 """529 Entry point for extract logic.530531 Use schema provided to create the dataframe utilizing532 field names and column starts and lengths. 
Also adds meta data.533 """534 path = config['source_path']535 schema = config['schema']536537 if not schema:538 raise IngestSchemaError()539540 df = self.get_spark().read.text(path)541 df = self.filter_df(df)542543 for field in schema["fields"]:544 df = df.withColumn(field["name"], F.col("value").substr(field["start"], field["length"]).cast(field["type"]))545546 df = self.process_columns(df)547 #df.cache()548549 return df 550551class IngestExcelFile(BaseIngest):552 """553 """554555 def __get_pd_datatype(self, field_type: str) -> np.dtype:556 """557 Function to map Spark SQL types to Numpy types558 """ 559 return {560 'string': str,561 'timestamp': np.dtype('datetime64[ns]'),562 'double': np.float64,563 'float': np.float64,564 'integer': np.int64,565 'bigint': np.int64,566 }[field_type]567568 def __create_pd_excel_schema(self, schema: dict) -> dict:569 """570 Function to generate Pandas schema from JSON Config schema571 """572 pd_schema = {}573574 #Read json schema here and generate pandas schema575 for index, field in enumerate(schema["fields"]):576 pd_schema[index] = self.__get_pd_datatype(field["type"])577578 return pd_schema579580 @staticmethod581 def __convert_nan_to_null(data_frame: DataFrame) -> DataFrame:582 """583 Function to convert NaN string values to NULL584 """ 585 data_frame = data_frame.select([586 F.when(F.isnan(c), None).otherwise(F.col(c)).alias(c) if t in ("double", "float", "integer", "bigint") else c587 for c, t in data_frame.dtypes588 ])589590 data_frame = data_frame.select([591 F.when(F.col(c) == 'nan', None).otherwise(F.col(c)).alias(c) if t == "string" else c592 for c, t in data_frame.dtypes593 ]) 594595 return data_frame596597 @staticmethod598 def __validate_type_conversions(data_frame: DataFrame) -> DataFrame:599 """600 Function to find and record value errors601 """602 validate_column_dtypes = False603604 for c, t in data_frame.dtypes:605 if t in ("double", "float", "int", "bigint"):606 validate_column_dtypes = True607608 if validate_column_dtypes:609 predicate = " or ".join(610 ['isnan({})'.format(c) for c, t in data_frame.dtypes if t in ("double", "float", "int", "bigint")]611 )612 613 iptmeta_corrupt_record = "case when {} then concat_ws('|', *) else NULL end".format(predicate)614 data_frame = data_frame.withColumn("iptmeta_corrupt_record", F.expr(iptmeta_corrupt_record))615616 return data_frame617618 @staticmethod619 def __validate_key_columns(data_frame: DataFrame, key_columns: list) -> DataFrame:620 """621 Function to find and record null key values622 """623 if key_columns:624 predicate = " or ".join(['isnull({})'.format(i) for i in key_columns])625 iptmeta_corrupt_record = "case when {} then concat_ws('|', *) else NULL end".format(predicate)626627 data_frame = data_frame.withColumn("iptmeta_corrupt_record", F.expr(iptmeta_corrupt_record))628 629 return data_frame630631 @staticmethod632 def __get_file_paths(path: str) -> list:633 """634 Gets a list of files based on the given path. This method635 handles S3 file system paths and local file system paths. 
Paths can be636 an absolute path to a single file or a globbed folder/file path.637 """638 # The list of Excel files we'll be importing into a dataframe639 file_paths = []640641 # If the path is an S3 filesystem, then use Boto3, the AWS SDK642 if path.startswith('s3://') and (path.endswith('*.xlsx') or path.endswith('/')):643 # Remove everything after last forward slash644 path = path.rsplit('/', 1)[0] + '/';645 # Get the bucket name from the path646 bucket_name = re.search('s3://(.*?)/', path).group(1)647 # Get the folder prefix648 prefix = re.search('s3://' + bucket_name + '/(.*)', path).group(1)649 # Get the S3 bucket650 bucket = boto3.resource('s3').Bucket(bucket_name)651 # Build a list of file paths from S3 Bucket652 for obj in bucket.objects.filter(Prefix=prefix):653 # Skip if object path is the parent folder since we only want file paths654 if obj.key == prefix or not obj.key.endswith(".xlsx"):655 continue656657 file_paths.append("s3://{0}/{1}".format(bucket_name, obj.key))658 elif path.startswith('/') and "*" in path:659 # If path starts with a forward slash then assume we're running locally660 file_paths = glob.glob(path)661 else:662 # An absolute path to a file is given (S3 or local) so just append663 file_paths.append(path)664665 return file_paths666667 def extract(self, config: dict) -> DataFrame:668 """669 """670 read_options = config['read_options']671 schema = config.get('schema')672 _read_options = read_options.copy()673674 self.IMPORT_NANS = [675 '', 'N/A', 'NULL', 'null', 'NaN', 'n/a', 'nan', '#N/A', '#N/A N/A', '#NA', '-1.#IND',676 '-1.#QNAN', '-NaN', '-nan', '1.#IND', '1.#QNAN', 'None', '(blank)', ' '677 ]678 _read_options["na_values"] = self.IMPORT_NANS679 680 if not schema or not schema.get('fields'):681 raise IngestSchemaError()682 else:683 schema_struct = T.StructType.fromJson(schema) 684 pd_schema = self.__create_pd_excel_schema(schema=schema)685 _read_options['dtype'] = pd_schema686 687 # replaces normalize_column_names logic, built-in 'names' attribute is meant for this688 field_names = [i["name"] for i in schema['fields']]689 _read_options["names"] = field_names690 691 # Set range on skip rows if defined in config692 if "skiprows" in _read_options:693 start_row= _read_options["skiprows"]["start"]694 stop_row = _read_options["skiprows"]["stop"]695 696 _read_options["skiprows"] = list(range(start_row, stop_row))697698 file_paths = self.__get_file_paths(_read_options["io"])699700 # Remove i/o read_option since we already have a list of excel files to read701 _read_options.pop("io", None)702703 # If sheet_optional flag is set to True, return False if any Excel file does not contain the sheet704 sheet_optional = config.get('sheet_optional')705 if sheet_optional:706 sheet_name = read_options['sheet_name']707 for file in file_paths:708 xl = pd.ExcelFile(file)709 if sheet_name not in xl.sheet_names:710 return False711712 import_pd_df = pd.concat((713 pd.read_excel(file, **_read_options)714 for file in file_paths))715 716 df = self.get_spark().createDataFrame(import_pd_df, schema_struct)717 df = self.__validate_type_conversions(df)718 df = self.__convert_nan_to_null(df) 719 df = self.__validate_key_columns(df, config.get('key_columns'))720 #df.cache()721722 return df723724725class IngestDbSource(BaseIngest):726727 def extract(self, config: dict) -> DataFrame:728 """729 Function to read data in from sqoop.db in hive.730 """731 hive_db = config['hive_db']732 hive_table = config['hive_table']733734 spark = self.get_spark()735 query = 'SELECT * FROM 
sqoop.{}_{}'.format(hive_db, hive_table)736 df = spark.sql(query)737738 return df739740741class IngestError(Exception):742 """743 Base class for exceptions in this module.744 """745 pass746747class IngestSchemaError(IngestError):748 """749 Exception raised for ingest schema errors.750 expression -- input expression in which the error occurred751 message -- explanation of the error752 """753754 def __init__(self):755 self.expression = "Schema, with list of fields, is required."756 self.message = "Schema, with list of fields, is required."757758class IngestTypeError(IngestError):759 """Exception raised for ingest type errors.760761 Attributes:762 expression -- input expression in which the error occurred763 message -- explanation of the error764 """765 def __init__(self, ingest_type=None, target_table=None, key_columns=None):766 self.expression = 'ingest_type: {}, target_table: {}, key_columns: {}'.format(767 ingest_type,768 target_table,769 key_columns770 )771772 self.message = "Valid ingest types are 'full' and 'incremental'. \773 If 'incremental', source keys and target table must be supplied "774775class IngestPrimaryKeyViolationError(IngestError):776 """Exception raised for ingest primary key violations.777778 Source violates configured primary key779780 Attributes:781 expression -- input expression in which the error occurred782 message -- explanation of the error783 """784 def __init__(self, errMessage):785 self.expression = errMessage786 self.message = errMessage787 788class IngestNoDataError(IngestError):789 """Exception raised when Sqoop table is empty790791 Attributes:792 expression -- input expression in which the error occurred793 message -- explanation of the error794 """795 def __init__(self, errMessage):796 self.expression = errMessage
...
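The short sketch below is not part of ingest.py; it is a minimal usage example, assuming a hypothetical job-config JSON file (the bucket, path, and config keys shown are placeholders), of how the classes above are intended to be driven: construct a concrete ingest inside a with block so __exit__ stops the SparkSession, then call execute().

# Minimal usage sketch (hypothetical config paths and values).
from ingest import IngestCSV

SPARK_CONFIG = {
    "spark.sql.hive.convertMetastoreOrc": "true",
    "spark.sql.sources.partitionOverwriteMode": "DYNAMIC",
}

# Local/dev style json_config; in cluster mode config_bucket and config_path
# would instead be collected from the spark-submit arguments.
JSON_CONFIG = {
    "config_bucket": "/vagrant/mill_profitability_ingest/job-config/",  # placeholder
    "config_path": "extract_example_invoices.json",                     # placeholder
}

if __name__ == "__main__":
    with IngestCSV(spark_config=SPARK_CONFIG, json_config=JSON_CONFIG) as job:
        # execute() extracts, validates the configured primary key, stages the
        # SOURCE/TARGET partitions, writes to the target table, and archives.
        job.execute()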
csv_parser.py
Source:csv_parser.py
...
    def format(self) -> CsvFormat:
        if self.format_model is None:
            self.format_model = CsvFormat.parse_obj(self._format)
        return self.format_model

    def _read_options(self) -> Mapping[str, str]:
        """
        https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html
        build ReadOptions object like: pa.csv.ReadOptions(**self._read_options())
        """
        return {
            **{"block_size": self.format.block_size, "encoding": self.format.encoding},
            **json.loads(self.format.advanced_options),
        }

    def _parse_options(self) -> Mapping[str, str]:
        """
        https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
        build ParseOptions object like: pa.csv.ParseOptions(**self._parse_options())
        """
        return {
            "delimiter": self.format.delimiter,
            "quote_char": self.format.quote_char,
            "double_quote": self.format.double_quote,
            "escape_char": self.format.escape_char,
            "newlines_in_values": self.format.newlines_in_values,
        }

    def _convert_options(self, json_schema: Mapping[str, Any] = None) -> Mapping[str, Any]:
        """
        https://arrow.apache.org/docs/python/generated/pyarrow.csv.ConvertOptions.html
        build ConvertOptions object like: pa.csv.ConvertOptions(**self._convert_options())
        :param json_schema: if this is passed in, pyarrow will attempt to enforce this schema on read, defaults to None
        """
        check_utf8 = self.format.encoding.lower().replace("-", "") == "utf8"
        convert_schema = self.json_schema_to_pyarrow_schema(json_schema) if json_schema is not None else None
        return {
            **{"check_utf8": check_utf8, "column_types": convert_schema},
            **json.loads(self.format.additional_reader_options),
        }

    def get_inferred_schema(self, file: Union[TextIO, BinaryIO]) -> Mapping[str, Any]:
        """
        https://arrow.apache.org/docs/python/generated/pyarrow.csv.open_csv.html
        This now uses multiprocessing in order to timeout the schema inference as it can hang.
        Since the hanging code is resistant to signal interrupts, threading/futures doesn't help so we needed to multiprocess.
        https://issues.apache.org/jira/browse/ARROW-11853?page=com.atlassian.jira.plugin.system.issuetabpanels%3Aall-tabpanel
        """

        def infer_schema_process(
            file_sample: str, read_opts: dict, parse_opts: dict, convert_opts: dict
        ) -> Tuple[dict, Optional[Exception]]:
            """
            we need to reimport here to be functional on Windows systems since it doesn't have fork()
            https://docs.python.org/3.7/library/multiprocessing.html#contexts-and-start-methods
            This returns a tuple of (schema_dict, None OR Exception).
            If return[1] is not None and holds an exception we then raise this in the main process.
            This lets us propagate up any errors (that aren't timeouts) and raise correctly.
            """
            try:
                import tempfile

                import pyarrow as pa

                # writing our file_sample to a temporary file to then read in and schema infer as before
                with tempfile.TemporaryFile() as fp:
                    fp.write(file_sample)  # type: ignore[arg-type]
                    fp.seek(0)
                    streaming_reader = pa.csv.open_csv(
                        fp, pa.csv.ReadOptions(**read_opts), pa.csv.ParseOptions(**parse_opts), pa.csv.ConvertOptions(**convert_opts)
                    )
                    schema_dict = {field.name: field.type for field in streaming_reader.schema}
            except Exception as e:
                # we pass the traceback up otherwise the main process won't know the exact method+line of error
                return (None, e)
            else:
                return (schema_dict, None)

        # boto3 objects can't be pickled (https://github.com/boto/boto3/issues/678)
        # and so we can't multiprocess with the actual fileobject on Windows systems
        # we're reading block_size*2 bytes here, which we can then pass in and infer schema from block_size bytes
        # the *2 is to give us a buffer as pyarrow figures out where lines actually end so it gets schema correct
        schema_dict = self._get_schema_dict(file, infer_schema_process)
        return self.json_schema_to_pyarrow_schema(schema_dict, reverse=True)  # type: ignore[no-any-return]

    def _get_schema_dict(self, file: Union[TextIO, BinaryIO], infer_schema_process: Callable) -> Mapping[str, Any]:
        if not self.format.infer_datatypes:
            return self._get_schema_dict_without_inference(file)
        self.logger.debug("inferring schema")
        file_sample = file.read(self._read_options()["block_size"] * 2)  # type: ignore[arg-type]
        return run_in_external_process(
            fn=infer_schema_process,
            timeout=4,
            max_timeout=60,
            logger=self.logger,
            args=[
                file_sample,
                self._read_options(),
                self._parse_options(),
                self._convert_options(),
            ],
        )

    # TODO Rename this here and in `_get_schema_dict`
    def _get_schema_dict_without_inference(self, file: Union[TextIO, BinaryIO]) -> Mapping[str, Any]:
        self.logger.debug("infer_datatypes is False, skipping infer_schema")
        delimiter = self.format.delimiter
        quote_char = self.format.quote_char
        reader = csv.reader([six.ensure_text(file.readline())], delimiter=delimiter, quotechar=quote_char)
        field_names = next(reader)
        return {field_name.strip(): pyarrow.string() for field_name in field_names}

    def stream_records(self, file: Union[TextIO, BinaryIO]) -> Iterator[Mapping[str, Any]]:
        """
        https://arrow.apache.org/docs/python/generated/pyarrow.csv.open_csv.html
        PyArrow returns lists of values for each column so we zip() these up into records which we then yield
        """
        streaming_reader = pa_csv.open_csv(
            file,
            pa.csv.ReadOptions(**self._read_options()),
            pa.csv.ParseOptions(**self._parse_options()),
            pa.csv.ConvertOptions(**self._convert_options(self._master_schema)),
        )
        still_reading = True
        while still_reading:
            try:
                batch = streaming_reader.read_next_batch()
            except StopIteration:
                still_reading = False
            else:
                batch_dict = batch.to_pydict()
                batch_columns = [col_info.name for col_info in batch.schema]
                # this gives us a list of lists where each nested list holds ordered values for a single column
                # e.g. [ [1,2,3], ["a", "b", "c"], [True, True, False] ]
...
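The standalone sketch below is not part of the parser class; it illustrates, under assumed option values and a hypothetical file path, the pattern the docstrings describe: plain dicts of kwargs expanded into pyarrow's ReadOptions / ParseOptions / ConvertOptions, then streamed batch by batch with open_csv much as stream_records does.

# Option-dict pattern with pyarrow.csv (example values are assumptions).
import pyarrow.csv as pa_csv

read_opts = {"block_size": 1 << 20, "encoding": "utf-8"}
parse_opts = {"delimiter": ",", "quote_char": '"', "double_quote": True,
              "escape_char": False, "newlines_in_values": False}
convert_opts = {"check_utf8": True, "column_types": None}

with open("example.csv", "rb") as f:  # hypothetical input file
    reader = pa_csv.open_csv(
        f,
        pa_csv.ReadOptions(**read_opts),
        pa_csv.ParseOptions(**parse_opts),
        pa_csv.ConvertOptions(**convert_opts),
    )
    while True:
        try:
            batch = reader.read_next_batch()
        except StopIteration:
            break
        # zip the per-column value lists back into per-row records
        batch_dict = batch.to_pydict()
        for values in zip(*batch_dict.values()):
            print(dict(zip(batch_dict.keys(), values)))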
LossConfigurator.py
Source:LossConfigurator.py
...
        if isinstance(config, dict):
            step_separation = np.any([(step_name in config) for step_name in step_names])
            if step_separation:
                losses = {}
                global_options = self._read_options(config)
                for step_name in step_names:
                    if step_name in config:
                        step_config = config[step_name]
                        losses.update({step_name: self._read_loss_config(step_config, global_options)})
                # append objective configuration to other steps
                if "training" in losses:
                    if len(losses["training"]) > 0:
                        objective = losses["training"][0]
                        for step_name in losses.keys():
                            if step_name != "training":
                                step_losses = losses[step_name]
                                loss_names = [loss.name for loss in step_losses]
                                if objective.name not in loss_names:
                                    step_losses.append(objective)
            else:
                losses = {"training": [self._read_loss_config(config)]}
        elif isinstance(config, (list, tuple)):
            losses = {"training": self._read_loss_config(config)}
        else:
            raise Exception('[ERROR] Unknown configuration format.')
        for step_name in step_names:
            if step_name in losses:
                losses.update({
                    step_name: LossCollection(*losses[step_name])
                })
        return losses

    def _read_loss_config(self, config, global_options=None):
        loss_list = []
        if isinstance(config, dict):
            if "losses" in config:
                local_options = self._read_options(config, global_options)
                loss_config = config["losses"]
                loss_list = self._read_loss_config(loss_config, local_options)
            if "type" in config:
                loss_type = LossType(config["type"])
                local_options = self._read_options(config, global_options)
                kwargs = {
                    kw: config[kw]
                    for kw in config.keys()
                    if kw not in local_options and kw != "losses"
                }
                loss_constructor = __loss_functions__[loss_type]
                loss_config = global_options.copy() if global_options is not None else {}
                loss_config.update(config)
                if loss_type != LossType.WEIGHTED:
                    loss_list = [loss_constructor(**{**loss_config, **kwargs})]
                else:
                    raise NotImplementedError()
                    # assert len(loss_list) > 0
                    # loss_list = [loss_constructor(*loss_list, **{**loss_config, **kwargs})]
        elif isinstance(config, (list, tuple)):
            for loss_config in config:
                loss_list += self._read_loss_config(loss_config, global_options)
        elif isinstance(config, str):
            loss_type = LossType(config.upper())
            assert loss_type != LossType.WEIGHTED
            loss_constructor = __loss_functions__[loss_type]
            kwargs = global_options.copy() if global_options is not None else {}
            loss_list = [loss_constructor(**kwargs)]
        else:
            raise Exception('[ERROR] Unknown configuration format.')
        return loss_list

    @staticmethod
    def _read_options(config, global_options=None):
        option_names = ['use_scalings', 'use_mask', 'batch_reduction', 'spatial_reduction']
        if global_options is None:
            option_dict = {kw: None for kw in option_names}
        else:
            option_dict = {kw: global_options[kw] for kw in option_names}
        for kw in option_names:
            if kw in config:
                option_dict.update({kw: config[kw]})
...
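For orientation, the snippet below is a hypothetical configuration sketch (the step names other than "training" and the loss-type strings are assumptions, since only the option names appear in the code above) illustrating the shapes _read_loss_config accepts: a plain string, a list of loss dicts, or a per-step dict whose global options (use_scalings, use_mask, batch_reduction, spatial_reduction) are merged into each loss, with the first "training" loss treated as the objective and appended to the other steps.

# Hypothetical loss configuration shapes for LossConfigurator.
config_simple = "l1"  # single loss, assigned to the "training" step

config_per_step = {
    "use_mask": True,              # global option applied to every loss below
    "batch_reduction": "mean",     # assumed option value
    "training": [
        {"type": "l2", "use_scalings": False},  # first training loss = objective
    ],
    "validation": [                # assumed step name
        {"type": "l1"},            # the objective is appended here automatically
    ],
}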