mirror of
https://github.com/hedge-dev/XenonRecomp.git
synced 2025-12-11 14:34:58 +00:00
Initial Commit
This commit is contained in:
264
thirdparty/capstone/suite/synctools/tablegen/include/llvm/IR/MatrixBuilder.h
vendored
Normal file
264
thirdparty/capstone/suite/synctools/tablegen/include/llvm/IR/MatrixBuilder.h
vendored
Normal file
@@ -0,0 +1,264 @@
|
||||
//===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the MatrixBuilder class, which is used as a convenient way
|
||||
// to lower matrix operations to LLVM IR.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_IR_MATRIXBUILDER_H
|
||||
#define LLVM_IR_MATRIXBUILDER_H
|
||||
|
||||
#include "llvm/IR/Constant.h"
|
||||
#include "llvm/IR/Constants.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/InstrTypes.h"
|
||||
#include "llvm/IR/Instruction.h"
|
||||
#include "llvm/IR/IntrinsicInst.h"
|
||||
#include "llvm/IR/Type.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "llvm/Support/Alignment.h"
|
||||
|
||||
namespace llvm {
|
||||
|
||||
class Function;
|
||||
class Twine;
|
||||
class Module;
|
||||
|
||||
template <class IRBuilderTy> class MatrixBuilder {
|
||||
IRBuilderTy &B;
|
||||
Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
|
||||
|
||||
std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
|
||||
Value *RHS) {
|
||||
assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
|
||||
"One of the operands must be a matrix (embedded in a vector)");
|
||||
if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
|
||||
assert(!isa<ScalableVectorType>(LHS->getType()) &&
|
||||
"LHS Assumed to be fixed width");
|
||||
RHS = B.CreateVectorSplat(
|
||||
cast<VectorType>(LHS->getType())->getElementCount(), RHS,
|
||||
"scalar.splat");
|
||||
} else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
|
||||
assert(!isa<ScalableVectorType>(RHS->getType()) &&
|
||||
"RHS Assumed to be fixed width");
|
||||
LHS = B.CreateVectorSplat(
|
||||
cast<VectorType>(RHS->getType())->getElementCount(), LHS,
|
||||
"scalar.splat");
|
||||
}
|
||||
return {LHS, RHS};
|
||||
}
|
||||
|
||||
public:
|
||||
MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
|
||||
|
||||
/// Create a column major, strided matrix load.
|
||||
/// \p DataPtr - Start address of the matrix read
|
||||
/// \p Rows - Number of rows in matrix (must be a constant)
|
||||
/// \p Columns - Number of columns in matrix (must be a constant)
|
||||
/// \p Stride - Space between columns
|
||||
CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment,
|
||||
Value *Stride, bool IsVolatile, unsigned Rows,
|
||||
unsigned Columns, const Twine &Name = "") {
|
||||
|
||||
// Deal with the pointer
|
||||
PointerType *PtrTy = cast<PointerType>(DataPtr->getType());
|
||||
Type *EltTy = PtrTy->getPointerElementType();
|
||||
|
||||
auto *RetType = FixedVectorType::get(EltTy, Rows * Columns);
|
||||
|
||||
Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows),
|
||||
B.getInt32(Columns)};
|
||||
Type *OverloadedTypes[] = {RetType, Stride->getType()};
|
||||
|
||||
Function *TheFn = Intrinsic::getDeclaration(
|
||||
getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes);
|
||||
|
||||
CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
|
||||
Attribute AlignAttr =
|
||||
Attribute::getWithAlignment(Call->getContext(), Alignment);
|
||||
Call->addParamAttr(0, AlignAttr);
|
||||
return Call;
|
||||
}
|
||||
|
||||
/// Create a column major, strided matrix store.
|
||||
/// \p Matrix - Matrix to store
|
||||
/// \p Ptr - Pointer to write back to
|
||||
/// \p Stride - Space between columns
|
||||
CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment,
|
||||
Value *Stride, bool IsVolatile,
|
||||
unsigned Rows, unsigned Columns,
|
||||
const Twine &Name = "") {
|
||||
Value *Ops[] = {Matrix, Ptr,
|
||||
Stride, B.getInt1(IsVolatile),
|
||||
B.getInt32(Rows), B.getInt32(Columns)};
|
||||
Type *OverloadedTypes[] = {Matrix->getType(), Stride->getType()};
|
||||
|
||||
Function *TheFn = Intrinsic::getDeclaration(
|
||||
getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes);
|
||||
|
||||
CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
|
||||
Attribute AlignAttr =
|
||||
Attribute::getWithAlignment(Call->getContext(), Alignment);
|
||||
Call->addParamAttr(1, AlignAttr);
|
||||
return Call;
|
||||
}
|
||||
|
||||
/// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows
|
||||
/// rows and \p Columns columns.
|
||||
CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows,
|
||||
unsigned Columns, const Twine &Name = "") {
|
||||
auto *OpType = cast<VectorType>(Matrix->getType());
|
||||
auto *ReturnType =
|
||||
FixedVectorType::get(OpType->getElementType(), Rows * Columns);
|
||||
|
||||
Type *OverloadedTypes[] = {ReturnType};
|
||||
Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)};
|
||||
Function *TheFn = Intrinsic::getDeclaration(
|
||||
getModule(), Intrinsic::matrix_transpose, OverloadedTypes);
|
||||
|
||||
return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
|
||||
}
|
||||
|
||||
/// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
|
||||
/// RHS.
|
||||
CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
|
||||
unsigned LHSColumns, unsigned RHSColumns,
|
||||
const Twine &Name = "") {
|
||||
auto *LHSType = cast<VectorType>(LHS->getType());
|
||||
auto *RHSType = cast<VectorType>(RHS->getType());
|
||||
|
||||
auto *ReturnType =
|
||||
FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
|
||||
|
||||
Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns),
|
||||
B.getInt32(RHSColumns)};
|
||||
Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType};
|
||||
|
||||
Function *TheFn = Intrinsic::getDeclaration(
|
||||
getModule(), Intrinsic::matrix_multiply, OverloadedTypes);
|
||||
return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
|
||||
}
|
||||
|
||||
/// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p
|
||||
/// ColumnIdx).
|
||||
Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx,
|
||||
Value *ColumnIdx, unsigned NumRows) {
|
||||
return B.CreateInsertElement(
|
||||
Matrix, NewVal,
|
||||
B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get(
|
||||
ColumnIdx->getType(), NumRows)),
|
||||
RowIdx));
|
||||
}
|
||||
|
||||
/// Add matrixes \p LHS and \p RHS. Support both integer and floating point
|
||||
/// matrixes.
|
||||
Value *CreateAdd(Value *LHS, Value *RHS) {
|
||||
assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
|
||||
if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
|
||||
assert(!isa<ScalableVectorType>(LHS->getType()) &&
|
||||
"LHS Assumed to be fixed width");
|
||||
RHS = B.CreateVectorSplat(
|
||||
cast<VectorType>(LHS->getType())->getElementCount(), RHS,
|
||||
"scalar.splat");
|
||||
} else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
|
||||
assert(!isa<ScalableVectorType>(RHS->getType()) &&
|
||||
"RHS Assumed to be fixed width");
|
||||
LHS = B.CreateVectorSplat(
|
||||
cast<VectorType>(RHS->getType())->getElementCount(), LHS,
|
||||
"scalar.splat");
|
||||
}
|
||||
|
||||
return cast<VectorType>(LHS->getType())
|
||||
->getElementType()
|
||||
->isFloatingPointTy()
|
||||
? B.CreateFAdd(LHS, RHS)
|
||||
: B.CreateAdd(LHS, RHS);
|
||||
}
|
||||
|
||||
/// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
|
||||
/// point matrixes.
|
||||
Value *CreateSub(Value *LHS, Value *RHS) {
|
||||
assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
|
||||
if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) {
|
||||
assert(!isa<ScalableVectorType>(LHS->getType()) &&
|
||||
"LHS Assumed to be fixed width");
|
||||
RHS = B.CreateVectorSplat(
|
||||
cast<VectorType>(LHS->getType())->getElementCount(), RHS,
|
||||
"scalar.splat");
|
||||
} else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) {
|
||||
assert(!isa<ScalableVectorType>(RHS->getType()) &&
|
||||
"RHS Assumed to be fixed width");
|
||||
LHS = B.CreateVectorSplat(
|
||||
cast<VectorType>(RHS->getType())->getElementCount(), LHS,
|
||||
"scalar.splat");
|
||||
}
|
||||
|
||||
return cast<VectorType>(LHS->getType())
|
||||
->getElementType()
|
||||
->isFloatingPointTy()
|
||||
? B.CreateFSub(LHS, RHS)
|
||||
: B.CreateSub(LHS, RHS);
|
||||
}
|
||||
|
||||
/// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
|
||||
/// RHS.
|
||||
Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
|
||||
std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
|
||||
if (LHS->getType()->getScalarType()->isFloatingPointTy())
|
||||
return B.CreateFMul(LHS, RHS);
|
||||
return B.CreateMul(LHS, RHS);
|
||||
}
|
||||
|
||||
/// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
|
||||
/// IsUnsigned indicates whether UDiv or SDiv should be used.
|
||||
Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
|
||||
assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
|
||||
assert(!isa<ScalableVectorType>(LHS->getType()) &&
|
||||
"LHS Assumed to be fixed width");
|
||||
RHS =
|
||||
B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
|
||||
RHS, "scalar.splat");
|
||||
return cast<VectorType>(LHS->getType())
|
||||
->getElementType()
|
||||
->isFloatingPointTy()
|
||||
? B.CreateFDiv(LHS, RHS)
|
||||
: (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
|
||||
}
|
||||
|
||||
/// Create an assumption that \p Idx is less than \p NumElements.
|
||||
void CreateIndexAssumption(Value *Idx, unsigned NumElements,
|
||||
Twine const &Name = "") {
|
||||
|
||||
Value *NumElts =
|
||||
B.getIntN(Idx->getType()->getScalarSizeInBits(), NumElements);
|
||||
auto *Cmp = B.CreateICmpULT(Idx, NumElts);
|
||||
if (auto *ConstCond = dyn_cast<ConstantInt>(Cmp))
|
||||
assert(ConstCond->isOne() && "Index must be valid!");
|
||||
else
|
||||
B.CreateAssumption(Cmp);
|
||||
}
|
||||
|
||||
/// Compute the index to access the element at (\p RowIdx, \p ColumnIdx) from
|
||||
/// a matrix with \p NumRows embedded in a vector.
|
||||
Value *CreateIndex(Value *RowIdx, Value *ColumnIdx, unsigned NumRows,
|
||||
Twine const &Name = "") {
|
||||
|
||||
unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(),
|
||||
ColumnIdx->getType()->getScalarSizeInBits());
|
||||
Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth);
|
||||
RowIdx = B.CreateZExt(RowIdx, IntTy);
|
||||
ColumnIdx = B.CreateZExt(ColumnIdx, IntTy);
|
||||
Value *NumRowsV = B.getIntN(MaxWidth, NumRows);
|
||||
return B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace llvm
|
||||
|
||||
#endif // LLVM_IR_MATRIXBUILDER_H
|
||||
Reference in New Issue
Block a user