Fix incorrect SSA construction

Dmitry Stogov 2023-10-12 10:17:04 +03:00
parent 1fbb2ac2ed
commit d51efd33d4
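
The previous code filled a merge block's predecessor slots in whatever order its incoming edges happened to be processed, using a per-block counter (predecessor_refs_count), and patched loop headers by operand position (ir_MERGE_SET_OP(merge, n + 1, ref)). That position does not necessarily correspond to the CFG edge the control ref came from, which is presumably how PHI/SSA inputs ended up attached to the wrong predecessors. The new helper llvm2ir_set_predecessor_ref() and the reworked llvm2ir_patch_merge() instead scan the recorded predecessor edges of the target block and store (or patch in) the ref in the slot whose edge matches the current source block, so MERGE/LOOP_BEGIN inputs stay aligned with their edges.

As a rough standalone sketch (not part of the patch; names and types are simplified stand-ins for the ir_ctx machinery), the change amounts to placing a ref by matching the recorded edge instead of appending at the next free counter position:

    #include <assert.h>
    #include <stdint.h>

    /* edges[0..count-1] holds the source-block id recorded for each
     * predecessor slot of a target block; refs[] holds the control ref
     * produced on that edge (0 = not yet set). */
    static void place_ref_by_edge(uint32_t src_block, int32_t ref, uint32_t count,
                                  const uint32_t *edges, int32_t *refs)
    {
        for (uint32_t n = 0; n < count; n++) {
            if (edges[n] == src_block && !refs[n]) { /* slot for this edge, still empty */
                refs[n] = ref;
                return;
            }
        }
        assert(0 && "no empty slot for this predecessor edge");
    }

A side benefit is that an already-emitted MERGE can be patched in place when a back edge is discovered later: llvm2ir_patch_merge() now locates the matching empty operand itself rather than relying on an externally maintained counter.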


@@ -947,11 +947,38 @@ static void llvm2ir_bb_start(ir_ctx *ctx, LLVMBasicBlockRef bb, LLVMBasicBlockRe
 	}
 }
 
-static void llvm2ir_patch_merge(ir_ctx *ctx, ir_ref merge, uint32_t n, ir_ref ref)
+static void llvm2ir_set_predecessor_ref(uint32_t b, ir_ref ref, uint32_t count, uint32_t *edges, ir_ref *refs)
 {
-	IR_ASSERT(ctx->ir_base[merge].op == IR_MERGE || ctx->ir_base[merge].op == IR_LOOP_BEGIN);
-	ctx->ir_base[merge].op = IR_LOOP_BEGIN;
-	ir_MERGE_SET_OP(merge, n + 1, ref);
+	do {
+		if (*edges == b && !*refs) {
+			*refs = ref;
+			return;
+		}
+		edges++;
+		refs++;
+		count--;
+	} while (count);
+	IR_ASSERT(0);
+}
+
+static void llvm2ir_patch_merge(ir_ctx *ctx, ir_ref merge, ir_ref ref, uint32_t b, uint32_t *edges)
+{
+	ir_insn *insn = &ctx->ir_base[merge];
+	ir_ref *ops = insn->ops + 1;
+	uint32_t count = insn->inputs_count;
+
+	IR_ASSERT(insn->op == IR_MERGE || insn->op == IR_LOOP_BEGIN);
+	insn->op = IR_LOOP_BEGIN;
+	do {
+		if (*edges == b && !*ops) {
+			*ops = ref;
+			return;
+		}
+		edges++;
+		ops++;
+		count--;
+	} while (count);
+	IR_ASSERT(0);
 }
 
 static uint32_t llvm2ir_compute_post_order(
@@ -1004,7 +1031,7 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 	ir_ref ref, max_inputs_count;
 	ir_hashtab bb_hash;
 	ir_use_list *predecessors;
-	uint32_t *predecessor_edges, *predecessor_refs_count;
+	uint32_t *predecessor_edges;
 	ir_ref *inputs, *bb_starts, *predecessor_refs;
 
 	// TODO: function prototype
@@ -1071,7 +1098,6 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 	}
 	predecessor_edges = ir_mem_malloc(sizeof(uint32_t) * count);
 	predecessor_refs = ir_mem_calloc(sizeof(ir_ref), count);
-	predecessor_refs_count = ir_mem_calloc(sizeof(uint32_t), bb_count);
 	inputs = ir_mem_malloc(sizeof(ir_ref) * max_inputs_count);
 	for (i = 0; i < bb_count; i++) {
 		bb = bbs[i];
@@ -1125,12 +1151,12 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 				IR_ASSERT(b < bb_count);
 				if (predecessors[b].count > 1) {
 					if (ir_bitset_in(visited, b)) {
-						llvm2ir_patch_merge(ctx, bb_starts[b], predecessor_refs_count[b], ref);
+						llvm2ir_patch_merge(ctx, bb_starts[b], ref, i, predecessor_edges + predecessors[b].refs);
 						ctx->ir_base[ref].op = IR_LOOP_END;
 					} else {
-						predecessor_refs[predecessors[b].refs + predecessor_refs_count[b]] = ref;
+						llvm2ir_set_predecessor_ref(i, ref, predecessors[b].count,
+							predecessor_edges + predecessors[b].refs, predecessor_refs + predecessors[b].refs);
 					}
-					predecessor_refs_count[b]++;
 				}
 			} else {
 				ref = llvm2ir_if(ctx, insn);
@@ -1139,11 +1165,11 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 				if (predecessors[b].count > 1) {
 					ir_IF_TRUE(ref);
 					if (ir_bitset_in(visited, b)) {
-						llvm2ir_patch_merge(ctx, bb_starts[b], predecessor_refs_count[b], ir_LOOP_END());
+						llvm2ir_patch_merge(ctx, bb_starts[b], ir_LOOP_END(), i, predecessor_edges + predecessors[b].refs);
 					} else {
-						predecessor_refs[predecessors[b].refs + predecessor_refs_count[b]] = ir_END();;
+						llvm2ir_set_predecessor_ref(i, ir_END(), predecessors[b].count,
+							predecessor_edges + predecessors[b].refs, predecessor_refs + predecessors[b].refs);
 					}
-					predecessor_refs_count[b]++;
 				} else {
 					IR_ASSERT(!ir_bitset_in(visited, b));
 				}
@@ -1152,11 +1178,11 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 				if (predecessors[b].count > 1) {
 					ir_IF_FALSE(ref);
 					if (ir_bitset_in(visited, b)) {
-						llvm2ir_patch_merge(ctx, bb_starts[b], predecessor_refs_count[b], ir_LOOP_END());
+						llvm2ir_patch_merge(ctx, bb_starts[b], ir_LOOP_END(), i, predecessor_edges + predecessors[b].refs);
 					} else {
-						predecessor_refs[predecessors[b].refs + predecessor_refs_count[b]] = ir_END();;
+						llvm2ir_set_predecessor_ref(i, ir_END(), predecessors[b].count,
+							predecessor_edges + predecessors[b].refs, predecessor_refs + predecessors[b].refs);
 					}
-					predecessor_refs_count[b]++;
 				} else {
 					IR_ASSERT(!ir_bitset_in(visited, b));
 				}
@@ -1169,11 +1195,11 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 				if (predecessors[b].count > 1) {
 					ir_CASE_DEFAULT(ref);
 					if (ir_bitset_in(visited, b)) {
-						llvm2ir_patch_merge(ctx, bb_starts[b], predecessor_refs_count[b], ir_LOOP_END());
+						llvm2ir_patch_merge(ctx, bb_starts[b], ir_LOOP_END(), i, predecessor_edges + predecessors[b].refs);
 					} else {
-						predecessor_refs[predecessors[b].refs + predecessor_refs_count[b]] = ir_END();
+						llvm2ir_set_predecessor_ref(i, ir_END(), predecessors[b].count,
+							predecessor_edges + predecessors[b].refs, predecessor_refs + predecessors[b].refs);
 					}
-					predecessor_refs_count[b]++;
 				} else {
 					IR_ASSERT(!ir_bitset_in(visited, b));
 				}
@@ -1186,11 +1212,11 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 					ir_type type = llvm2ir_type(LLVMTypeOf(val));
 					ir_CASE_VAL(ref, llvm2ir_op(ctx, val, type));
 					if (ir_bitset_in(visited, b)) {
-						llvm2ir_patch_merge(ctx, bb_starts[b], predecessor_refs_count[b], ir_LOOP_END());
+						llvm2ir_patch_merge(ctx, bb_starts[b], ir_LOOP_END(), i, predecessor_edges + predecessors[b].refs);
 					} else {
-						predecessor_refs[predecessors[b].refs + predecessor_refs_count[b]] = ir_END();
+						llvm2ir_set_predecessor_ref(i, ir_END(), predecessors[b].count,
+							predecessor_edges + predecessors[b].refs, predecessor_refs + predecessors[b].refs);
 					}
-					predecessor_refs_count[b]++;
 				} else {
 					IR_ASSERT(!ir_bitset_in(visited, b));
 				}
@@ -1356,7 +1382,6 @@ static int llvm2ir_func(ir_ctx *ctx, LLVMModuleRef module, LLVMValueRef func)
 	ir_mem_free(visited);
 	ir_mem_free(post_order);
 	ir_mem_free(predecessor_refs);
-	ir_mem_free(predecessor_refs_count);
 	ir_mem_free(bb_starts);
 	ir_addrtab_free(ctx->binding);
 	ir_mem_free(ctx->binding);