#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import glob
import os
import struct
import sys
import unittest
import difflib
import functools
import math
from decimal import Decimal
from time import time, sleep
from typing import (
    Any,
    Optional,
    Union,
    Dict,
    List,
    Callable,
)
from itertools import zip_longest

have_scipy = False
have_numpy = False
try:
    import scipy  # noqa: F401

    have_scipy = True
except ImportError:
    # No SciPy, but that's okay, we'll skip those tests
    pass
try:
    import numpy as np  # noqa: F401

    have_numpy = True
except ImportError:
    # No NumPy, but that's okay, we'll skip those tests
    pass

from pyspark import SparkConf
from pyspark.errors import PySparkAssertionError, PySparkException, PySparkTypeError
from pyspark.errors.exceptions.captured import CapturedException
from pyspark.errors.exceptions.base import QueryContextType
from pyspark.find_spark_home import _find_spark_home
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField
from pyspark.sql.functions import col, when

__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]

SPARK_HOME = _find_spark_home()


def read_int(b):
    return struct.unpack("!i", b)[0]


def write_int(i):
    return struct.pack("!i", i)
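
# Illustrative sketch (not part of the original module): `read_int` and
# `write_int` round-trip a Python int through the 4-byte big-endian framing
# ("!i") used above. The helper name below is hypothetical.
def _example_int_round_trip():
    framed = write_int(1000)  # b'\x00\x00\x03\xe8'
    assert len(framed) == 4  # "!i" always packs to exactly 4 bytes
    assert read_int(framed) == 1000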
"""asserttimeout>0assertisinstance(catch_assertions,bool)defdecorator(condition:Callable)->Callable:assertisinstance(condition,Callable)@functools.wraps(condition)defwrapper(*args:Any,**kwargs:Any)->Any:start_time=time()lastValue=NonenumTries=0whiletime()-start_time<timeout:numTries+=1ifcatch_assertions:try:lastValue=condition(*args,**kwargs)exceptAssertionErrorase:lastValue=eelse:lastValue=condition(*args,**kwargs)iflastValueisTrueorlastValueisNone:returnprint(f"\nAttempt #{numTries} failed!\n{lastValue}")sleep(0.01)ifisinstance(lastValue,AssertionError):raiselastValueelse:raiseAssertionError("Test failed due to timeout after %g sec, with last condition returning: %s"%(timeout,lastValue))returnwrapperreturndecoratorclassQuietTest:def__init__(self,sc):self.log4j=sc._jvm.org.apache.log4jdef__enter__(self):self.old_level=self.log4j.LogManager.getRootLogger().getLevel()self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL)def__exit__(self,exc_type,exc_val,exc_tb):self.log4j.LogManager.getRootLogger().setLevel(self.old_level)classPySparkTestCase(unittest.TestCase):defsetUp(self):frompysparkimportSparkContextself._old_sys_path=list(sys.path)class_name=self.__class__.__name__self.sc=SparkContext("local[4]",class_name)deftearDown(self):self.sc.stop()sys.path=self._old_sys_pathclassReusedPySparkTestCase(unittest.TestCase):@classmethoddefconf(cls):""" Override this in subclasses to supply a more specific conf """returnSparkConf()@classmethoddefsetUpClass(cls):frompysparkimportSparkContextcls.sc=SparkContext("local[4]",cls.__name__,conf=cls.conf())@classmethoddeftearDownClass(cls):cls.sc.stop()deftest_assert_classic_mode(self):frompyspark.sqlimportis_remoteself.assertFalse(is_remote())defquiet(self):frompyspark.testing.utilsimportQuietTestreturnQuietTest(self.sc)classByteArrayOutput:def__init__(self):self.buffer=bytearray()defwrite(self,b):self.buffer+=bdefclose(self):passdefsearch_jar(project_relative_path,sbt_jar_name_prefix,mvn_jar_name_prefix):# Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can# vary for SBT or Maven specifically. See also SPARK-26856project_full_path=os.path.join(SPARK_HOME,project_relative_path)# We should ignore the following jarsignored_jar_suffixes=("javadoc.jar","sources.jar","test-sources.jar","tests.jar")# Search jar in the project dir using the jar name_prefix for both sbt build and maven# build because the artifact jars are in different directories.sbt_build=glob.glob(os.path.join(project_full_path,"target/scala-*/%s*.jar"%sbt_jar_name_prefix))maven_build=glob.glob(os.path.join(project_full_path,"target/%s*.jar"%mvn_jar_name_prefix))jar_paths=sbt_build+maven_buildjars=[jarforjarinjar_pathsifnotjar.endswith(ignored_jar_suffixes)]ifnotjars:returnNoneeliflen(jars)>1:raiseRuntimeError("Found multiple JARs: %s; please remove all but one"%(", ".join(jars)))else:returnjars[0]def_terminal_color_support():try:# determine if environment supports colorscript="$(test $(tput colors)) && $(test $(tput colors) -ge 8) && echo true || echo false"returnos.popen(script).read()exceptException:returnFalsedef_context_diff(actual:List[str],expected:List[str],n:int=3):""" Modified from difflib context_diff API, see original code here: https://github.com/python/cpython/blob/main/Lib/difflib.py#L1180 """defred(s:str)->str:red_color="\033[31m"no_color="\033[0m"returnred_color+str(s)+no_colorprefix=dict(insert="+ ",delete="- ",replace="! 
",equal=" ")forgroupindifflib.SequenceMatcher(None,actual,expected).get_grouped_opcodes(n):yield"*** actual ***"ifany(tagin{"replace","delete"}fortag,_,_,_,_ingroup):fortag,i1,i2,_,_ingroup:forlineinactual[i1:i2]:iftag!="equal"and_terminal_color_support():yieldred(prefix[tag]+str(line))else:yieldprefix[tag]+str(line)yield"\n"yield"*** expected ***"ifany(tagin{"replace","insert"}fortag,_,_,_,_ingroup):fortag,_,_,j1,j2ingroup:forlineinexpected[j1:j2]:iftag!="equal"and_terminal_color_support():yieldred(prefix[tag]+str(line))else:yieldprefix[tag]+str(line)classPySparkErrorTestUtils:""" This util provide functions to accurate and consistent error testing based on PySpark error classes. """defcheck_error(self,exception:PySparkException,errorClass:str,messageParameters:Optional[Dict[str,str]]=None,query_context_type:Optional[QueryContextType]=None,fragment:Optional[str]=None,):query_context=exception.getQueryContext()assertbool(query_context)==(query_context_typeisnotNone),("`query_context_type` is required when QueryContext exists. "f"QueryContext: {query_context}.")# Test if given error is an instance of PySparkException.self.assertIsInstance(exception,PySparkException,f"checkError requires 'PySparkException', got '{exception.__class__.__name__}'.",)# Test error classexpected=errorClassactual=exception.getErrorClass()self.assertEqual(expected,actual,f"Expected error class was '{expected}', got '{actual}'.")# Test message parametersexpected=messageParametersactual=exception.getMessageParameters()self.assertEqual(expected,actual,f"Expected message parameters was '{expected}', got '{actual}'")# Test query contextifquery_context:expected=query_context_typeactual_contexts=exception.getQueryContext()foractual_contextinactual_contexts:actual=actual_context.contextType()self.assertEqual(expected,actual,f"Expected QueryContext was '{expected}', got '{actual}'")ifactual==QueryContextType.DataFrame:assert(fragmentisnotNone),"`fragment` is required when QueryContextType is DataFrame."expected=fragmentactual=actual_context.fragment()self.assertEqual(expected,actual,f"Expected PySpark fragment was '{expected}', got '{actual}'",)
def assertSchemaEqual(
    actual: StructType,
    expected: StructType,
    ignoreNullable: bool = True,
    ignoreColumnOrder: bool = False,
    ignoreColumnName: bool = False,
):
    r"""
    A util function to assert equality between DataFrame schemas `actual` and `expected`.

    .. versionadded:: 3.5.0

    Parameters
    ----------
    actual : StructType
        The DataFrame schema that is being compared or tested.
    expected : StructType
        The expected schema, for comparison with the actual schema.
    ignoreNullable : bool, default True
        Specifies whether a column's nullable property is included when checking for
        schema equality.
        When set to `True` (default), the nullable property of the columns being compared
        is not taken into account and the columns will be considered equal even if they have
        different nullable settings.
        When set to `False`, columns are considered equal only if they have the same nullable
        setting.

        .. versionadded:: 4.0.0

    ignoreColumnOrder : bool, default False
        Specifies whether to compare columns in the order they appear in the DataFrame or by
        column name.
        If set to `False` (default), columns are compared in the order they appear in the
        DataFrames.
        When set to `True`, a column in the expected DataFrame is compared to the column with
        the same name in the actual DataFrame.

        .. versionadded:: 4.0.0

    ignoreColumnName : bool, default False
        Specifies whether to fail the initial schema equality check if the column names in the
        two DataFrames are different.
        When set to `False` (default), column names are checked and the function fails if they
        are different.
        When set to `True`, the function will succeed even if column names are different.
        Column data types are compared for columns in the order they appear in the DataFrames.

        .. versionadded:: 4.0.0

    Notes
    -----
    When assertSchemaEqual fails, the error message uses the Python `difflib` library to
    display a diff log of the `actual` and `expected` schemas.

    Examples
    --------
    >>> from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType, DoubleType
    >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
    >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
    >>> assertSchemaEqual(s1, s2)  # pass, schemas are identical

    Different schemas with `ignoreNullable=False` would fail.

    >>> s3 = StructType([StructField("names", ArrayType(DoubleType(), True), False)])
    >>> assertSchemaEqual(s1, s3, ignoreNullable=False)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
    --- actual
    +++ expected
    - StructType([StructField('names', ArrayType(DoubleType(), True), True)])
    ?                                                                  ^^^
    + StructType([StructField('names', ArrayType(DoubleType(), True), False)])
    ?                                                                  ^^^^

    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"])
    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"])
    >>> assertSchemaEqual(df1.schema, df2.schema)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
    --- actual
    +++ expected
    - StructType([StructField('id', LongType(), True), StructField('number', LongType(), True)])
    ?                               ^^                              ^^^^^
    + StructType([StructField('id', StringType(), True), StructField('amount', LongType(), True)])
    ?                               ^^^^ ++++                         ^

    Compare two schemas ignoring the column order.

    >>> s1 = StructType(
    ...     [StructField("a", IntegerType(), True), StructField("b", DoubleType(), True)]
    ... )
    >>> s2 = StructType(
[StructField("b", DoubleType(), True), StructField("a", IntegerType(), True)] ... ) >>> assertSchemaEqual(s1, s2, ignoreColumnOrder=True) Compare two schemas ignoring the column names. >>> s1 = StructType( ... [StructField("a", IntegerType(), True), StructField("c", DoubleType(), True)] ... ) >>> s2 = StructType( ... [StructField("b", IntegerType(), True), StructField("d", DoubleType(), True)] ... ) >>> assertSchemaEqual(s1, s2, ignoreColumnName=True) """ifnotisinstance(actual,StructType):raisePySparkTypeError(errorClass="NOT_STRUCT",messageParameters={"arg_name":"actual","arg_type":type(actual).__name__},)ifnotisinstance(expected,StructType):raisePySparkTypeError(errorClass="NOT_STRUCT",messageParameters={"arg_name":"expected","arg_type":type(expected).__name__},)defcompare_schemas_ignore_nullable(s1:StructType,s2:StructType):iflen(s1)!=len(s2):returnFalsezipped=zip_longest(s1,s2)forsf1,sf2inzipped:ifnotcompare_structfields_ignore_nullable(sf1,sf2):returnFalsereturnTruedefcompare_structfields_ignore_nullable(actualSF:StructField,expectedSF:StructField):ifactualSFisNoneandexpectedSFisNone:returnTrueelifactualSFisNoneorexpectedSFisNone:returnFalseifactualSF.name!=expectedSF.name:returnFalseelse:returncompare_datatypes_ignore_nullable(actualSF.dataType,expectedSF.dataType)defcompare_datatypes_ignore_nullable(dt1:Any,dt2:Any):# checks datatype equality, using recursion to ignore nullableifdt1.typeName()==dt2.typeName():ifdt1.typeName()=="array":returncompare_datatypes_ignore_nullable(dt1.elementType,dt2.elementType)elifdt1.typeName()=="struct":returncompare_schemas_ignore_nullable(dt1,dt2)else:returnTrueelse:returnFalseifignoreColumnOrder:actual=StructType(sorted(actual,key=lambdax:x.name))expected=StructType(sorted(expected,key=lambdax:x.name))ifignoreColumnName:actual=StructType([StructField(str(i),field.dataType,field.nullable)fori,fieldinenumerate(actual)])expected=StructType([StructField(str(i),field.dataType,field.nullable)fori,fieldinenumerate(expected)])if(ignoreNullableandnotcompare_schemas_ignore_nullable(actual,expected))or(notignoreNullableandactual!=expected):generated_diff=difflib.ndiff(str(actual).splitlines(),str(expected).splitlines())error_msg="\n".join(generated_diff)raisePySparkAssertionError(errorClass="DIFFERENT_SCHEMA",messageParameters={"error_msg":error_msg},)
def assertDataFrameEqual(
    actual: Union[DataFrame, "pandas.DataFrame", "pyspark.pandas.DataFrame", List[Row]],
    expected: Union[DataFrame, "pandas.DataFrame", "pyspark.pandas.DataFrame", List[Row]],
    checkRowOrder: bool = False,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    ignoreNullable: bool = True,
    ignoreColumnOrder: bool = False,
    ignoreColumnName: bool = False,
    ignoreColumnType: bool = False,
    maxErrors: Optional[int] = None,
    showOnlyDiff: bool = False,
    includeDiffRows: bool = False,
):
    r"""
    A util function to assert equality between `actual` and `expected`
    (DataFrames or lists of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`.

    Supports Spark, Spark Connect, pandas, and pandas-on-Spark DataFrames.
    For more information about pandas-on-Spark DataFrame equality, see the docs for
    `assertPandasOnSparkEqual`.

    .. versionadded:: 3.5.0

    Parameters
    ----------
    actual : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
        The DataFrame that is being compared or tested.
    expected : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
        The expected result of the operation, for comparison with the actual result.
    checkRowOrder : bool, optional
        A flag indicating whether the order of rows should be considered in the comparison.
        If set to `False` (default), the row order is not taken into account.
        If set to `True`, the order of rows is important and will be checked during comparison.
        (See Notes)
    rtol : float, optional
        The relative tolerance, used in asserting approximate equality for float values in
        actual and expected. Set to 1e-5 by default. (See Notes)
    atol : float, optional
        The absolute tolerance, used in asserting approximate equality for float values in
        actual and expected. Set to 1e-8 by default. (See Notes)
    ignoreNullable : bool, default True
        Specifies whether a column's nullable property is included when checking for
        schema equality.
        When set to `True` (default), the nullable property of the columns being compared
        is not taken into account and the columns will be considered equal even if they have
        different nullable settings.
        When set to `False`, columns are considered equal only if they have the same nullable
        setting.

        .. versionadded:: 4.0.0

    ignoreColumnOrder : bool, default False
        Specifies whether to compare columns in the order they appear in the DataFrame or by
        column name.
        If set to `False` (default), columns are compared in the order they appear in the
        DataFrames.
        When set to `True`, a column in the expected DataFrame is compared to the column with
        the same name in the actual DataFrame.

        .. versionadded:: 4.0.0

    ignoreColumnName : bool, default False
        Specifies whether to fail the initial schema equality check if the column names in the
        two DataFrames are different.
        When set to `False` (default), column names are checked and the function fails if they
        are different.
        When set to `True`, the function will succeed even if column names are different.
        Column data types are compared for columns in the order they appear in the DataFrames.

        .. versionadded:: 4.0.0

    ignoreColumnType : bool, default False
        Specifies whether to ignore the data type of the columns when comparing.
        When set to `False` (default), column data types are checked and the function fails
        if they are different.
        When set to `True`, the schema equality check will succeed even if column data types
        are different and the function will attempt to compare rows.

        .. versionadded:: 4.0.0

    maxErrors : int, optional
        The maximum number of row comparison failures to encounter before returning.
        Once this many row comparisons have failed, the function returns, regardless of how
        many rows have been compared. Defaults to None, which means all rows are compared
        regardless of the number of failures.

        .. versionadded:: 4.0.0

    showOnlyDiff : bool, default False
        If set to `True`, the error message will only include rows that are different.
        If set to `False` (default), the error message will include all rows
        (when there is at least one row that is different).

        .. versionadded:: 4.0.0

    includeDiffRows : bool, default False
        If set to `True`, the unequal rows are included in PySparkAssertionError
        for further debugging.
        If set to `False` (default), the unequal rows are not returned as a data set.

        .. versionadded:: 4.0.0

    Notes
    -----
    When `assertDataFrameEqual` fails, the error message uses the Python `difflib` library to
    display a diff log of each row that differs in `actual` and `expected`.

    For `checkRowOrder`, note that PySpark DataFrame ordering is non-deterministic, unless
    explicitly sorted.

    Note that schema equality is checked only when `expected` is a DataFrame
    (not a list of Rows).

    For DataFrames with float/decimal values, assertDataFrameEqual asserts approximate equality.
    Two float/decimal values a and b are approximately equal if the following equation is True:

    ``absolute(a - b) <= (atol + rtol * absolute(b))``.

    `ignoreColumnOrder` and `ignoreColumnName` cannot both be set to `True`.

    Examples
    --------
    >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
    >>> assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical

    >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
    >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol

    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"])
    >>> list_of_rows = [Row(1, 1000), Row(2, 3000)]
    >>> assertDataFrameEqual(df1, list_of_rows)  # pass, actual and expected data are equal

    >>> import pyspark.pandas as ps  # doctest: +SKIP
    >>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})  # doctest: +SKIP
    >>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})  # doctest: +SKIP
    >>> # pass, pandas-on-Spark DataFrames are equal
    >>> assertDataFrameEqual(df1, df2)  # doctest: +SKIP

    >>> df1 = spark.createDataFrame(
    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
    >>> df2 = spark.createDataFrame(
    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
    >>> assertDataFrameEqual(df1, df2)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
    *** actual ***
    ! Row(id='1', amount=1000.0)
      Row(id='2', amount=3000.0)
    ! Row(id='3', amount=2000.0)
    *** expected ***
    ! Row(id='1', amount=1001.0)
      Row(id='2', amount=3000.0)
    ! Row(id='3', amount=2003.0)

    Example for ignoreNullable

    >>> from pyspark.sql.types import StructType, StructField, StringType, LongType
    >>> df1_nullable = spark.createDataFrame(
    ...     data=[(1000, "1"), (5000, "2")],
    ...     schema=StructType(
    ...         [StructField("amount", LongType(), True), StructField("id", StringType(), True)]
    ...     )
    ... )
    >>> df2_nullable = spark.createDataFrame(
    ...     data=[(1000, "1"), (5000, "2")],
    ...     schema=StructType(
    ...         [StructField("amount", LongType(), True), StructField("id", StringType(), False)]
    ...     )
    ... )
    >>> assertDataFrameEqual(df1_nullable, df2_nullable, ignoreNullable=True)  # pass
    >>> assertDataFrameEqual(
    ...     df1_nullable, df2_nullable, ignoreNullable=False
    ... )  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
    --- actual
    +++ expected
    - StructType([StructField('amount', LongType(), True), StructField('id', StringType(), True)])
    ?                                                                                       ^^^
    + StructType([StructField('amount', LongType(), True), StructField('id', StringType(), False)])
    ?                                                                                       ^^^^

    Example for ignoreColumnOrder

    >>> df1_col_order = spark.createDataFrame(
    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
    ... )
    >>> df2_col_order = spark.createDataFrame(
    ...     data=[("1", 1000), ("2", 5000)], schema=["id", "amount"]
    ... )
    >>> assertDataFrameEqual(df1_col_order, df2_col_order, ignoreColumnOrder=True)

    Example for ignoreColumnName

    >>> df1_col_names = spark.createDataFrame(
    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "identity"]
    ... )
    >>> df2_col_names = spark.createDataFrame(
    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
    ... )
    >>> assertDataFrameEqual(df1_col_names, df2_col_names, ignoreColumnName=True)

    Example for ignoreColumnType

    >>> df1_col_types = spark.createDataFrame(
    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
    ... )
    >>> df2_col_types = spark.createDataFrame(
    ...     data=[(1000.0, "1"), (5000.0, "2")], schema=["amount", "id"]
    ... )
    >>> assertDataFrameEqual(df1_col_types, df2_col_types, ignoreColumnType=True)

    Example for maxErrors (will only report the first mismatching row)

    >>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
    >>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
    >>> assertDataFrameEqual(df1, df2, maxErrors=1)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 33.33333 % )
    *** actual ***
      Row(_1=1, _2='A')
    ! Row(_1=2, _2='B')
    *** expected ***
      Row(_1=1, _2='A')
    ! Row(_1=2, _2='X')

    Example for showOnlyDiff (will only report the mismatching rows)

    >>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
    >>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
    >>> assertDataFrameEqual(df1, df2, showOnlyDiff=True)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
    *** actual ***
    ! Row(_1=2, _2='B')
    ! Row(_1=3, _2='C')
    *** expected ***
    ! Row(_1=2, _2='X')
    ! Row(_1=3, _2='Y')

    The `includeDiffRows` parameter can be used to include the rows that did not match
    in the PySparkAssertionError. This can be useful for debugging or further analysis.

    >>> df1 = spark.createDataFrame(
    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
    >>> df2 = spark.createDataFrame(
    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
    >>> try:
    ...     assertDataFrameEqual(df1, df2, includeDiffRows=True)
    ... except PySparkAssertionError as e:
    ...     spark.createDataFrame(e.data).show()  # doctest: +NORMALIZE_WHITESPACE
    +-----------+-----------+
    |         _1|         _2|
    +-----------+-----------+
    |{1, 1000.0}|{1, 1001.0}|
    |{3, 2000.0}|{3, 2003.0}|
    +-----------+-----------+
    """
    if actual is None and expected is None:
        return True
    elif actual is None:
        raise PySparkAssertionError(
            errorClass="INVALID_TYPE_DF_EQUALITY_ARG",
            messageParameters={
                "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
                "arg_name": "actual",
                "actual_type": None,
            },
        )
    elif expected is None:
        raise PySparkAssertionError(
            errorClass="INVALID_TYPE_DF_EQUALITY_ARG",
            messageParameters={
                "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
                "arg_name": "expected",
                "actual_type": None,
            },
        )

    has_pandas = False
    try:
        # If pandas dependencies are available, allow pandas or pandas-on-Spark DataFrame
        import pandas as pd

        has_pandas = True
    except ImportError:
        # no pandas, so we won't call pandasutils functions
        pass

    has_arrow = False
    try:
        import pyarrow  # noqa: F401

        has_arrow = True
    except ImportError:
        pass

    if has_pandas and has_arrow:
        import pyspark.pandas as ps
        from pyspark.testing.pandasutils import PandasOnSparkTestUtils

        if (
            isinstance(actual, pd.DataFrame)
            or isinstance(expected, pd.DataFrame)
            or isinstance(actual, ps.DataFrame)
            or isinstance(expected, ps.DataFrame)
        ):
            # handle pandas DataFrames
            # assert approximate equality for float data
            return PandasOnSparkTestUtils().assert_eq(
                actual, expected, almost=True, rtol=rtol, atol=atol, check_row_order=checkRowOrder
            )

    if not isinstance(actual, (DataFrame, list)):
        raise PySparkAssertionError(
            errorClass="INVALID_TYPE_DF_EQUALITY_ARG",
            messageParameters={
                "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
                "arg_name": "actual",
                "actual_type": type(actual),
            },
        )
    elif not isinstance(expected, (DataFrame, list)):
        raise PySparkAssertionError(
            errorClass="INVALID_TYPE_DF_EQUALITY_ARG",
            messageParameters={
                "expected_type": "Union[DataFrame, ps.DataFrame, List[Row]]",
                "arg_name": "expected",
                "actual_type": type(expected),
            },
        )

    if ignoreColumnOrder:
        actual = actual.select(*sorted(actual.columns))
        expected = expected.select(*sorted(expected.columns))

    def rename_dataframe_columns(df: DataFrame) -> DataFrame:
        """Rename DataFrame columns to sequential numbers for comparison"""
        renamed_columns = [str(i) for i in range(len(df.columns))]
        return df.toDF(*renamed_columns)

    if ignoreColumnName:
        actual = rename_dataframe_columns(actual)
        expected = rename_dataframe_columns(expected)

    def cast_columns_to_string(df: DataFrame) -> DataFrame:
        """Cast all DataFrame columns to string for comparison"""
        for col_name in df.columns:
            # Add logic to remove trailing .0 for float columns that are whole numbers
            df = df.withColumn(
                col_name,
                when(
                    (col(col_name).cast("float").isNotNull())
                    & (col(col_name).cast("float") == col(col_name).cast("int")),
                    col(col_name).cast("int").cast("string"),
                ).otherwise(col(col_name).cast("string")),
            )
        return df

    if ignoreColumnType:
        actual = cast_columns_to_string(actual)
        expected = cast_columns_to_string(expected)

    def compare_rows(r1: Row, r2: Row):
        def compare_vals(val1, val2):
            if isinstance(val1, list) and isinstance(val2, list):
                return len(val1) == len(val2) and all(
                    compare_vals(x, y) for x, y in zip(val1, val2)
                )
            elif isinstance(val1, Row) and isinstance(val2, Row):
                return all(compare_vals(x, y) for x, y in zip(val1, val2))
            elif isinstance(val1, dict) and isinstance(val2, dict):
                return (
                    len(val1.keys()) == len(val2.keys())
                    and val1.keys() == val2.keys()
                    and all(compare_vals(val1[k], val2[k]) for k in val1.keys())
                )
            elif isinstance(val1, float) and isinstance(val2, float):
                if abs(val1 - val2) > (atol + rtol * abs(val2)):
                    return False
            elif isinstance(val1, Decimal) and isinstance(val2, Decimal):
                if abs(val1 - val2) > (Decimal(atol) + Decimal(rtol) * abs(val2)):
                    return False
            else:
                if val1 != val2:
                    return False
            return True

        if r1 is None and r2 is None:
            return True
        elif r1 is None or r2 is None:
            return False

        return compare_vals(r1, r2)

    def assert_rows_equal(
        rows1: List[Row],
        rows2: List[Row],
        maxErrors: Optional[int] = None,
        showOnlyDiff: bool = False,
    ):
        zipped = list(zip_longest(rows1, rows2))
        diff_rows_cnt = 0
        diff_rows = []
        has_diff_rows = False
        rows_str1 = ""
        rows_str2 = ""

        # count different rows
        for r1, r2 in zipped:
            if not compare_rows(r1, r2):
                diff_rows_cnt += 1
                has_diff_rows = True
                if includeDiffRows:
                    diff_rows.append((r1, r2))
                rows_str1 += str(r1) + "\n"
                rows_str2 += str(r2) + "\n"
                if maxErrors is not None and diff_rows_cnt >= maxErrors:
                    break
            elif not showOnlyDiff:
                rows_str1 += str(r1) + "\n"
                rows_str2 += str(r2) + "\n"

        generated_diff = _context_diff(
            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped)
        )

        if has_diff_rows:
            error_msg = "Results do not match: "
            percent_diff = (diff_rows_cnt / len(zipped)) * 100
            error_msg += "( %.5f %% )" % percent_diff
            error_msg += "\n" + "\n".join(generated_diff)
            data = diff_rows if includeDiffRows else None
            raise PySparkAssertionError(
                errorClass="DIFFERENT_ROWS",
                messageParameters={"error_msg": error_msg},
                data=data,
            )

    # only compare schema if expected is not a List
    if not isinstance(actual, list) and not isinstance(expected, list):
        if ignoreNullable:
            assertSchemaEqual(actual.schema, expected.schema)
        elif actual.schema != expected.schema:
            generated_diff = difflib.ndiff(
                str(actual.schema).splitlines(), str(expected.schema).splitlines()
            )
            error_msg = "\n".join(generated_diff)
            raise PySparkAssertionError(
                errorClass="DIFFERENT_SCHEMA",
                messageParameters={"error_msg": error_msg},
            )

    if not isinstance(actual, list):
        if actual.isStreaming:
            raise PySparkAssertionError(
                errorClass="UNSUPPORTED_OPERATION",
                messageParameters={"operation": "assertDataFrameEqual on streaming DataFrame"},
            )
        actual_list = actual.collect()
    else:
        actual_list = actual

    if not isinstance(expected, list):
        if expected.isStreaming:
            raise PySparkAssertionError(
                errorClass="UNSUPPORTED_OPERATION",
                messageParameters={"operation": "assertDataFrameEqual on streaming DataFrame"},
            )
        expected_list = expected.collect()
    else:
        expected_list = expected

    if not checkRowOrder:
        # rename duplicate columns for sorting
        actual_list = sorted(actual_list, key=lambda x: str(x))
        expected_list = sorted(expected_list, key=lambda x: str(x))

    assert_rows_equal(actual_list, expected_list, maxErrors=maxErrors, showOnlyDiff=showOnlyDiff)
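

# Illustrative sketch (not part of the original module): the approximate float
# equality used above follows ``abs(a - b) <= atol + rtol * abs(b)``, so the
# relative tolerance scales with the expected value while atol covers values
# near zero. The helper name and numbers below are hypothetical.
def _example_tolerance():
    actual_rows = [Row(amount=100.000999)]
    expected_rows = [Row(amount=100.0)]
    # |100.000999 - 100.0| = 0.000999 <= 1e-8 + 1e-5 * 100.0 = 0.00100001
    assertDataFrameEqual(actual_rows, expected_rows)  # passes with default rtol/atol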