15721这一章没什么好说的,不再贴课程内容了。codegen和simd在工业界一般只会选一种实现。比如photon之前用codegen,然后改成了向量化引擎。一般gen的都是weld IR/LLVM IR/当前语言,gen成C++的也要检查是不是有本地预编译版本,要不没法用。因为clickhouse没有codegen,这节课就拿我比较熟悉的spark的tungsten来当例子,tungsten会gen成Java源码,然后拿janino动态编译。
tungsten主要有两个特色:一个是codegen,另一个是on-heap memory的管理。本文顺便把它的内存管理也分析一下。在jvm堆内自由分配内存,不被free,不受gc影响,还是挺有意思的。
WSCG(Whole-Stage CodeGen)
手写代码的生成过程分为两个步骤:
- 从父节点到子节点,递归调用 doProduce,生成框架
- 从子节点到父节点,递归调用 doConsume,向框架填充每一个操作符的运算逻辑
首先,在 Stage 顶端节点也就是 Project 之上,添加 WholeStageCodeGen 节点。WholeStageCodeGen 节点通过调用 doExecute 来触发整个代码生成过程的计算。doExecute 会递归调用子节点的 doProduce 函数,直到遇到 Shuffle Boundary 为止。这里,Shuffle Boundary 指的是 Shuffle 边界,要么是数据源,要么是上一个 Stage 的输出。在叶子节点(也就是 Scan)调用的 doProduce 函数会先把手写代码的框架生成出来。
/**
 * Runs this whole-stage-codegen node: generates the fused source, compiles it, and evaluates
 * the compiled iterator over every partition of the input RDD(s).
 *
 * Falls back to `child.execute()` (i.e. no whole-stage codegen) in two cases:
 *  - compilation fails with a non-fatal error while not testing and `conf.codegenFallback`
 *    is enabled;
 *  - the largest compiled method exceeds `conf.hugeMethodLimit`.
 *
 * @return an RDD of rows produced by the generated `BufferedRowIterator`
 */
override def doExecute(): RDD[InternalRow] = {
  // doCodeGen() calls the child operator's produce(), which drives the recursive
  // produce/consume code generation for the whole stage.
  val (ctx, cleanedSource) = doCodeGen()
  // try to compile and fallback if it failed
  // (the author notes the dynamic compilation is performed by Janino)
  val (_, compiledCodeStats) = try {
    CodeGenerator.compile(cleanedSource)
  } catch {
    case NonFatal(_) if !Utils.isTesting && conf.codegenFallback =>
      // We should already saw the error message
      logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
      return child.execute()
  }

  // Check if compiled code has a too large function
  if (compiledCodeStats.maxMethodCodeSize > conf.hugeMethodLimit) {
    logInfo(s"Found too long generated codes and JIT optimization might not work: " +
      s"the bytecode size (${compiledCodeStats.maxMethodCodeSize}) is above the limit " +
      s"${conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
      s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
      s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
    return child.execute()
  }

  val references = ctx.references.toArray

  val durationMs = longMetric("pipelineTime")

  // Even though rdds is an RDD[InternalRow] it may actually be an RDD[ColumnarBatch] with
  // type erasure hiding that. This allows for the input to a code gen stage to be columnar,
  // but the output must be rows.
  val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
  assert(rdds.size <= 2, "Up to two input RDDs can be supported")
  if (rdds.length == 1) {
    rdds.head.mapPartitionsWithIndex { (index, iter) =>
      // The source is compiled again inside the partition closure and the generated class
      // is instantiated with the captured references.
      val (clazz, _) = CodeGenerator.compile(cleanedSource)
      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
      buffer.init(index, Array(iter))
      new Iterator[InternalRow] {
        override def hasNext: Boolean = {
          val v = buffer.hasNext
          // Once the generated iterator is exhausted, record its duration in the metric.
          if (!v) durationMs += buffer.durationMs()
          v
        }
        override def next: InternalRow = buffer.next()
      }
    }
  } else {
    // Right now, we support up to two input RDDs.
    rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
      Iterator((leftIter, rightIter))
      // a small hack to obtain the correct partition index
    }.mapPartitionsWithIndex { (index, zippedIter) =>
      val (leftIter, rightIter) = zippedIter.next()
      val (clazz, _) = CodeGenerator.compile(cleanedSource)
      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
      buffer.init(index, Array(leftIter, rightIter))
      new Iterator[InternalRow] {
        override def hasNext: Boolean = {
          val v = buffer.hasNext
          // Once the generated iterator is exhausted, record its duration in the metric.
          if (!v) durationMs += buffer.durationMs()
          v
        }
        override def next: InternalRow = buffer.next()
      }
    }
  }
}
/**
 * Generates the Java source for this whole stage.
 *
 * Asks the child operator to `produce` the body of `processNext`, registers that body as a
 * function on the codegen context, and wraps everything in a generated class that extends
 * `BufferedRowIterator`. The elapsed generation time is reported through
 * `WholeStageCodegenExec.increaseCodeGenTime`.
 *
 * @return the codegen context together with the cleaned-up source (code plus comments)
 */
def doCodeGen(): (CodegenContext, CodeAndComment) = {
  val startTime = System.nanoTime()
  val ctx = new CodegenContext
  val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)

  // main next function.
  ctx.addNewFunction("processNext",
    s"""
      protected void processNext() throws java.io.IOException {
        ${code.trim}
      }
     """, inlineToOuterClass = true)

  val className = generatedClassName()

  // The `|` margin in the registerComment text below must start its own line for
  // stripMargin to remove it.
  val source = s"""
    public Object generate(Object[] references) {
      return new $className(references);
    }

    ${ctx.registerComment(
      s"""Codegened pipeline for stage (id=$codegenStageId)
         |${this.treeString.trim}""".stripMargin,
       "wsc_codegenPipeline")}
    ${ctx.registerComment(s"codegenStageId=$codegenStageId", "wsc_codegenStageId", true)}
    final class $className extends ${classOf[BufferedRowIterator].getName} {

      private Object[] references;
      private scala.collection.Iterator[] inputs;
      ${ctx.declareMutableStates()}

      public $className(Object[] references) {
        this.references = references;
      }

      public void init(int index, scala.collection.Iterator[] inputs) {
        partitionIndex = index;
        this.inputs = inputs;
        ${ctx.initMutableStates()}
        ${ctx.initPartition()}
      }

      ${ctx.emitExtraCode()}

      ${ctx.declareAddedFunctions()}
    }
    """.trim

  // try to compile, helpful for debug
  val cleanedSource = CodeFormatter.stripOverlappingComments(
    new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments()))

  val duration = System.nanoTime() - startTime
  WholeStageCodegenExec.increaseCodeGenTime(duration)

  logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
  (ctx, cleanedSource)
}
然后,Scan 中的 doProduce 会反向递归调用每个父节点的 doConsume 函数。不同操作符在执行 doConsume 函数的过程中,会把关系表达式转化成 Java 代码,然后把这份代码像做“完形填空”一样,嵌入到刚刚的代码框架里。
doConsume代码不太好理解,我们以filter为例:
/**
 * Emits the filter's code: the predicate checks, a bump of the `numOutputRows` metric, and
 * the code of the downstream consumer, all wrapped in `do { } while(false);`.
 *
 * @param ctx   the codegen context
 * @param input generated variables for the child's output columns
 * @param row   generated row variable (not used here)
 * @return the generated source fragment for this operator
 */
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  val numOutput = metricTerm(ctx, "numOutputRows")

  val predicateCode = generatePredicateCode(
    ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)

  // Reset the isNull to false for the not-null columns, then the followed operators could
  // generate better code (remove dead branches).
  val resultVars = input.zipWithIndex.map { case (ev, i) =>
    if (notNullAttributes.contains(child.output(i).exprId)) {
      ev.isNull = FalseLiteral
    }
    ev
  }

  // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
  s"""
     |do {
     |  $predicateCode
     |  $numOutput.add(1);
     |  ${consume(ctx, resultVars)}
     |} while(false);
   """.stripMargin
}

/**
 * Generates the code that evaluates the filter predicates; each failed check emits
 * `continue;` to skip the current row.
 *
 * @param ctx                the codegen context
 * @param inputAttrs         the child's output attributes (child-side nullability)
 * @param inputExprCode      generated variables for the child's output columns
 * @param outputAttrs        this operator's output attributes (nullability after the
 *                           IsNotNull checks)
 * @param notNullPreds       the IsNotNull predicates
 * @param otherPreds         all remaining predicates
 * @param nonNullAttrExprIds expression ids of attributes that need an IsNotNull check
 * @return the generated predicate-checking code
 */
protected def generatePredicateCode(
    ctx: CodegenContext,
    inputAttrs: Seq[Attribute],
    inputExprCode: Seq[ExprCode],
    outputAttrs: Seq[Attribute],
    notNullPreds: Seq[Expression],
    otherPreds: Seq[Expression],
    nonNullAttrExprIds: Seq[ExprId]): String = {
  /**
   * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
   */
  def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
    val bound = BindReferences.bindReference(c, attrs)
    val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)

    // Generate the code for the predicate.
    val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
    val nullCheck = if (bound.nullable) {
      s"${ev.isNull} || "
    } else {
      s""
    }

    s"""
       |$evaluated
       |${ev.code}
       |if (${nullCheck}!${ev.value}) continue;
     """.stripMargin
  }

  // To generate the predicates we will follow this algorithm.
  // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
  // as necessary. For each of both attributes, if there is an IsNotNull predicate we will
  // generate that check *before* the predicate. After all of these predicates, we will generate
  // the remaining IsNotNull checks that were not part of other predicates.
  // This has the property of not doing redundant IsNotNull checks and taking better advantage of
  // short-circuiting, not loading attributes until they are needed.
  // This is very perf sensitive.
  // TODO: revisit this. We can consider reordering predicates as well.
  val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
  val extraIsNotNullAttrs = mutable.Set[Attribute]()
  val generated = otherPreds.map { c =>
    val nullChecks = c.references.map { r =>
      val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
      if (idx != -1 && !generatedIsNotNullChecks(idx)) {
        generatedIsNotNullChecks(idx) = true
        // Use the child's output. The nullability is what the child produced.
        genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
      } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
        extraIsNotNullAttrs += r
        genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
      } else {
        ""
      }
    }.mkString("\n").trim

    // Here we use *this* operator's output with this output's nullability since we already
    // enforced them with the IsNotNull checks above.
    s"""
       |$nullChecks
       |${genPredicate(c, inputExprCode, outputAttrs)}
     """.stripMargin.trim
  }.mkString("\n")

  // Emit the IsNotNull checks that were not already generated alongside another predicate.
  val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
    if (!generatedIsNotNullChecks(idx)) {
      genPredicate(c, inputExprCode, inputAttrs)
    } else {
      ""
    }
  }.mkString("\n")

  s"""
     |$generated
     |$nullChecks
   """.stripMargin
}
} // closes an enclosing definition that was opened outside this excerpt
这个地方先裁剪再判断,首先对涉及到谓词的is not null生成判断,之后进行裁剪,然后对裁剪后的列没有覆盖到is not null的再做一次is not null。这里的性能比较关键。
对于以下sql:
SELECT department, AVG(salary) AS avg_salary FROM employee GROUP BY department HAVING AVG(salary) > 60000
生成效果如下:
generated:
boolean filter_value_2 = !hashAgg_isNull_11; if (!filter_value_2) continue; boolean filter_value_3 = false; filter_value_3 = org.apache.spark.sql.catalyst.util.SQLOrderingUtil.compareDoubles(hashAgg_value_11, 60000.0D) > 0; if (!filter_value_3) continue;
如果加上一句where salary IS NOT NULL,那么在hashAgg之前,还会插入一段null的判断:
boolean rdd_isNull_3 = rdd_row_0.isNullAt(3); double rdd_value_3 = rdd_isNull_3 ? -1.0 : (rdd_row_0.getDouble(3)); boolean filter_value_2 = !rdd_isNull_3; if (!filter_value_2) continue;
内存管理
tungsten memory management
这里的idea很简单,重构对象模型但是不改变gc逻辑,于是tungsten抽象出了page table,来存放大量java native object,page table地址还是由jvm进行管理,拿到地址后在jvm堆内查找。
spark-core
在看spark-unsafe中的tungsten分配器之前, 我们先看下spark-core中的内存管理模块,
我们可以看到MemoryManager中的分配器已经默认换成了tungsten
/**
 * The allocator used by Unsafe/Tungsten code, selected by `tungstenMemoryMode`:
 * the heap allocator for on-heap mode and the unsafe allocator for off-heap mode.
 */
private[memory] final val tungstenMemoryAllocator: MemoryAllocator =
  tungstenMemoryMode match {
    case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE
    case MemoryMode.ON_HEAP => MemoryAllocator.HEAP
  }
MemoryManager就是用来管理Execution和Storage之间内存分配的类。
Execution和Storage都有堆内和堆外内存,使用内存池的方式由MemoryManager进行管理。
// Four memory pools: storage and execution, each in an on-heap and an off-heap variant.
// Every pool is constructed with this manager instance and its memory mode, and is
// annotated as guarded by `this`.

/** Storage pool for on-heap memory. */
@GuardedBy("this")
protected val onHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.ON_HEAP)

/** Storage pool for off-heap memory. */
@GuardedBy("this")
protected val offHeapStorageMemoryPool = new StorageMemoryPool(this, MemoryMode.OFF_HEAP)

/** Execution pool for on-heap memory. */
@GuardedBy("this")
protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.ON_HEAP)

/** Execution pool for off-heap memory. */
@GuardedBy("this")
protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.OFF_HEAP)
对于tungsten的实际调用在TaskMemoryManager中:
// 调用ExecutorMemoryManager进行内存分配,分配得到一个内存页,将其添加到 // page table中,用于内存地址映射 /** * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } final int pageNumber; synchronized (this) { // allocatedPages是一个bitmap // PAGE_TABLE_SIZE是两个内存页 8KB pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } allocatedPages.set(pageNumber); } try { page = memoryManager.tungstenMemoryAllocator().allocate(acquired); } catch (OutOfMemoryError e) { // 继续清理直到满足需要 logger.warn("Failed to allocate a page ({} bytes), try again.", acquired); // there is no enough memory actually, it means the actual free memory is smaller than // MemoryManager thought, we should keep the acquired memory. synchronized (this) { acquiredButNotUsed += acquired; allocatedPages.clear(pageNumber); } // this could trigger spilling to free some pages. return allocatePage(size, consumer); } page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); } return page; } 给定分配到的内存页和页内的偏移,生成一个64bits的逻辑地址 /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. * * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. * @param offsetInPage an offset in this page which incorporates the base offset. In other words, * this should be the value that you would pass as the base offset into an * UNSAFE call (e.g. 
page.baseOffset() + something). * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { if (!inHeap) { // In off-heap mode, an offset is an absolute address that may require a full 64 bits to // encode. Due to our page size limitation, though, we can convert this into an offset that's // relative to the page's base offset; this relative offset will fit in 51 bits. offsetInPage -= page.getBaseOffset(); } return encodePageNumberAndOffset(page.pageNumber, offsetInPage); } 高13bits是page number,低位为页内偏移 @VisibleForTesting public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); } 给定逻辑地址,获取page number @VisibleForTesting public static int decodePageNumber(long pagePlusOffsetAddress) { return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); } 给定逻辑地址,获取页内偏移 private static long decodeOffset(long pagePlusOffsetAddress) { return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); } 给定地址,获取内存页 /** * Get the page associated with an address encoded by * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final MemoryBlock page = pageTable[pageNumber]; assert (page != null); assert (page.getBaseObject() != null); return page.getBaseObject(); } else { return null; } } 给定地址获取页内偏移 /** * Get the offset associated with an address encoded by * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public long getOffsetInPage(long pagePlusOffsetAddress) { final long offsetInPage = decodeOffset(pagePlusOffsetAddress); if (inHeap) { return offsetInPage; } else { // In off-heap mode, an offset 
is an absolute address. In encodePageNumberAndOffset, we // converted the absolute address into a relative address. Here, we invert that operation: final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final MemoryBlock page = pageTable[pageNumber]; assert (page != null); return page.getBaseOffset() + offsetInPage; } }
spark-storage
spark-storage中类的关系比较复杂,不在这里展开,列一下几个重要类:
- BlockId:
表示 Spark 中数据块的唯一标识符。
依赖关系:通常作为其他存储相关类的参数或属性,例如 BlockManager。
- BlockInfo:
包含有关数据块的元数据信息。
依赖关系:依赖于 BlockId,并且可以与 BlockManager 一起使用。
- BlockManager:
负责管理分布式数据块的存储和检索。
依赖关系:依赖于 BlockId、BlockInfo 等类,与 DiskStore、MemoryStore 等一起协同工作。
- BlockManagerMaster:
管理集群中所有 BlockManager 的主节点。
依赖关系:依赖于 BlockManager,与 BlockManagerId 等协同工作。
- BlockManagerId:
表示 BlockManager 的唯一标识符。
依赖关系:通常作为 BlockManagerMaster 的参数,用于标识不同的 BlockManager。
- BlockManagerMasterEndpoint:
BlockManagerMaster 与其他节点通信的端点。
依赖关系:依赖于 BlockManagerMaster,与 RpcEndpoint 等一起使用。
- DiskBlockManager:
负责维护逻辑数据块与磁盘上物理文件位置之间的映射(创建并管理本地目录和块文件),它并不是 BlockManager 的一个实现。
依赖关系:由 BlockManager 创建并持有,DiskStore 通过它定位块文件。
- MemoryStore:
BlockManager 中负责将数据块存储在内存中的组件。
依赖关系:依赖于 BlockManager 和 MemoryManager,与 MemoryManager 等协同工作。
- DiskStore:
BlockManager 中负责将数据块持久化到磁盘的组件。
依赖关系:依赖于 BlockManager 和 DiskBlockManager。
- MemoryManager:
负责管理内存的组件,与 MemoryStore 等协同工作。
依赖关系:通常与 MemoryStore 和 BlockManager 一起使用。
- ShuffleBlockId:
用于表示与Shuffle相关的数据块的标识符。
依赖关系:依赖于 BlockId。
spark-unsafe
HeapMemoryAllocator实现了堆内存的实际分配
@GuardedBy("this") private final Map>> bufferPoolsBySize = new HashMap<>(); private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; /** * Returns true if allocations of the given size should go through the pooling mechanism and * false otherwise. */ private boolean shouldPool(long size) { // Very small allocations are less likely to benefit from pooling. return size >= POOLING_THRESHOLD_BYTES; }
这里使用一个弱引用的Long数组对于1M以上的回收内存进行资源池化,弱引用为了避免长时间未使用的数组一直保留在缓冲池中,消耗内存资源。
这也是spark内存使用不稳定的原因之一:弱引用对象的回收仍然是jvm控制的,没办法做到立即回收。
@Override public MemoryBlock allocate(long size) throws OutOfMemoryError { int numWords = (int) ((size + 7) / 8); long alignedSize = numWords * 8L; assert (alignedSize >= size); if (shouldPool(alignedSize)) { synchronized (this) { final LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool != null) { while (!pool.isEmpty()) { final WeakReference arrayReference = pool.pop(); final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } return memory; } } bufferPoolsBySize.remove(alignedSize); } } } long[] array = new long[numWords]; MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } return memory; }
free的时候总会清空MemoryBlock对数组的引用;如果对齐后的大小不小于1M,则把数组以弱引用的形式放回池中
@Override public void free(MemoryBlock memory) { assert (memory.obj != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; final long size = memory.size(); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } // Mark the page as freed (so we can detect double-frees). memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. long[] array = (long[]) memory.obj; memory.setObjAndOffset(null, 0); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { synchronized (this) { LinkedList> pool = bufferPoolsBySize.computeIfAbsent(alignedSize, k -> new LinkedList<>()); pool.add(new WeakReference<>(array)); } } }
猜你喜欢
网友评论
- 搜索
- 最新文章
- 热门文章