Initial commit
This commit is contained in:
7
aurac_codegen/Cargo.toml
Normal file
7
aurac_codegen/Cargo.toml
Normal file
@@ -0,0 +1,7 @@
|
||||
[package]
|
||||
name = "aurac_codegen"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
aurac_parser = { path = "../aurac_parser" }
|
||||
141
aurac_codegen/src/ir_gen.rs
Normal file
141
aurac_codegen/src/ir_gen.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use aurac_parser::ast::{Program, Decl, FnDecl, Block, Stmt, Expr, BinaryOp, TypeExpr};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub struct IrGenerator {
|
||||
pub output: String,
|
||||
pub tmp_counter: usize,
|
||||
pub env: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl IrGenerator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
output: String::new(),
|
||||
tmp_counter: 0,
|
||||
env: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_program(&mut self, program: &Program) -> String {
|
||||
for decl in &program.decls {
|
||||
if let Decl::Fn(fn_decl) = decl {
|
||||
self.generate_fn(fn_decl);
|
||||
}
|
||||
}
|
||||
self.output.clone()
|
||||
}
|
||||
|
||||
fn map_type(aura_type: &str) -> &'static str {
|
||||
match aura_type {
|
||||
"f32" | "f64" | "PositiveTime" => "float",
|
||||
"i32" | "i64" | "u32" | "u64" | "i8" | "i16" | "u8" | "u16" => "i32",
|
||||
"bool" => "i1",
|
||||
_ => "unknown_type",
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_fn(&mut self, decl: &FnDecl) {
|
||||
self.env.clear();
|
||||
self.tmp_counter = 0;
|
||||
|
||||
let ret_type_str = match &decl.return_type {
|
||||
TypeExpr::BaseType(bt) => bt.clone(),
|
||||
_ => "f32".to_string(), // fallback
|
||||
};
|
||||
let llvm_ret_type = Self::map_type(&ret_type_str);
|
||||
|
||||
if decl.is_gpu {
|
||||
self.output.push_str(&format!("define ptx_kernel {} @{}(", llvm_ret_type, decl.name));
|
||||
} else {
|
||||
self.output.push_str(&format!("define {} @{}(", llvm_ret_type, decl.name));
|
||||
}
|
||||
|
||||
for (i, param) in decl.params.iter().enumerate() {
|
||||
let param_type = match ¶m.ty {
|
||||
TypeExpr::BaseType(bt) => bt.clone(),
|
||||
_ => "f32".to_string(),
|
||||
};
|
||||
let llvm_param_type = Self::map_type(¶m_type);
|
||||
|
||||
self.output.push_str(&format!("{} %{}", llvm_param_type, param.name));
|
||||
|
||||
if i < decl.params.len() - 1 {
|
||||
self.output.push_str(", ");
|
||||
}
|
||||
}
|
||||
|
||||
if decl.is_gpu {
|
||||
self.output.push_str(") #0 {\nentry:\n");
|
||||
} else {
|
||||
self.output.push_str(") {\nentry:\n");
|
||||
}
|
||||
|
||||
self.generate_block(&decl.body, &ret_type_str);
|
||||
|
||||
self.output.push_str("}\n\n");
|
||||
|
||||
if decl.is_gpu {
|
||||
self.output.push_str("attributes #0 = { \"target-cpu\"=\"sm_70\" \"target-features\"=\"+ptx60\" }\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_block(&mut self, block: &Block, expected_ret_type: &str) {
|
||||
for stmt in &block.statements {
|
||||
self.generate_stmt(stmt, expected_ret_type);
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_stmt(&mut self, stmt: &Stmt, fn_ret_type: &str) {
|
||||
match stmt {
|
||||
Stmt::Return(expr) => {
|
||||
let val_reg = self.generate_expr(expr, fn_ret_type);
|
||||
let llvm_type = Self::map_type(fn_ret_type);
|
||||
self.output.push_str(&format!(" ret {} {}\n", llvm_type, val_reg));
|
||||
}
|
||||
Stmt::ExprStmt(expr) => {
|
||||
self.generate_expr(expr, fn_ret_type);
|
||||
}
|
||||
Stmt::LetBind(name, expr) => {
|
||||
// All test vars are f32 mathematically in this scenario
|
||||
let val_reg = self.generate_expr(expr, fn_ret_type);
|
||||
self.env.insert(name.clone(), val_reg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_expr(&mut self, expr: &Expr, expected_type: &str) -> String {
|
||||
match expr {
|
||||
Expr::Identifier(name) => {
|
||||
self.env.get(name).cloned().unwrap_or_else(|| format!("%{}", name))
|
||||
}
|
||||
Expr::Literal(val) => val.clone(),
|
||||
Expr::Binary(left, op, right) => {
|
||||
let left_val = self.generate_expr(left, expected_type);
|
||||
let right_val = self.generate_expr(right, expected_type);
|
||||
|
||||
let is_float = expected_type == "f32" || expected_type == "f64" || expected_type == "PositiveTime";
|
||||
let llvm_type = Self::map_type(expected_type);
|
||||
|
||||
let res_reg = format!("%{}", self.tmp_counter);
|
||||
self.tmp_counter += 1;
|
||||
|
||||
let instruction = match op {
|
||||
BinaryOp::Add => if is_float { "fadd" } else { "add" },
|
||||
BinaryOp::Sub => if is_float { "fsub" } else { "sub" },
|
||||
BinaryOp::Mul => if is_float { "fmul" } else { "mul" },
|
||||
BinaryOp::Div => if is_float { "fdiv" } else { "sdiv" },
|
||||
BinaryOp::Gt => if is_float { "fcmp ogt" } else { "icmp sgt" },
|
||||
BinaryOp::Lt => if is_float { "fcmp olt" } else { "icmp slt" },
|
||||
BinaryOp::Eq => if is_float { "fcmp oeq" } else { "icmp eq" },
|
||||
};
|
||||
|
||||
self.output.push_str(&format!(
|
||||
" {} = {} {} {}, {}\n",
|
||||
res_reg, instruction, llvm_type, left_val, right_val
|
||||
));
|
||||
|
||||
res_reg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
37
aurac_codegen/src/lib.rs
Normal file
37
aurac_codegen/src/lib.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
pub mod ir_gen;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ir_gen::IrGenerator;
|
||||
use aurac_parser::ast::{Program, Decl, FnDecl, Param, TypeExpr, Block, Stmt, Expr, BinaryOp};
|
||||
|
||||
#[test]
|
||||
fn test_generate_add_fn() {
|
||||
let program = Program {
|
||||
decls: vec![Decl::Fn(FnDecl {
|
||||
is_pure: true,
|
||||
is_gpu: false,
|
||||
name: "add".to_string(),
|
||||
params: vec![
|
||||
Param { name: "a".to_string(), ty: TypeExpr::BaseType("f32".to_string()) },
|
||||
Param { name: "b".to_string(), ty: TypeExpr::BaseType("f32".to_string()) },
|
||||
],
|
||||
return_type: TypeExpr::BaseType("f32".to_string()),
|
||||
body: Block {
|
||||
statements: vec![Stmt::Return(Expr::Binary(
|
||||
Box::new(Expr::Identifier("a".to_string())),
|
||||
BinaryOp::Add,
|
||||
Box::new(Expr::Identifier("b".to_string())),
|
||||
))],
|
||||
},
|
||||
})],
|
||||
};
|
||||
|
||||
let mut generator = IrGenerator::new();
|
||||
let ir = generator.generate_program(&program);
|
||||
|
||||
assert!(ir.contains("define float @add(float %a, float %b)"));
|
||||
assert!(ir.contains("fadd float %a, %b"));
|
||||
assert!(ir.contains("ret float %0"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user