Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.stella.typecheck
- import org.syntax.stella.Absyn.*
- import org.syntax.stella.Absyn.List
/**
 * Type checker for Stella programs.
 *
 * It covers the expression forms implemented in [ExpressionTypeChecker]. Every other
 * form is rejected with a [TypeCheckException].
 *
 * A program is checked in two passes over its top-level declarations:
 * 1. Every [DeclFun] signature is registered in a shared [TypeContext], so a body may
 *    call a function declared later in the program.
 * 2. Each body is checked against its declared return type.
 *
 * Types are compared with `==` / `!=`, i.e. with the generated Absyn classes' `equals`.
 * NOTE(review): this presumes the BNFC-generated `equals` is structural — confirm.
 */
object TypeCheck {
    /**
     * Type checks [program], returning normally when it is well-typed.
     *
     * @throws TypeCheckException if [program] is not an [AProgram], a declaration is
     *   malformed, an expression is ill-typed, or an unsupported expression form is used.
     *   Type-error messages start with a Stella error tag such as
     *   `ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION`.
     */
    @Throws(Exception::class)
    fun typecheckProgram(program: Program) {
        val context = TypeContext()
        registerBuiltins(context)

        val aProgram = program as? AProgram ?: throw TypeCheckException("Invalid program structure")
        val functions = aProgram.listdecl_.filterIsInstance<DeclFun>()
        // Pass 1: register every signature first, so bodies can reference functions
        // declared further down the program.
        functions.forEach { decl -> context.addFunction(decl.stellaident_, functionTypeOf(decl)) }
        // Pass 2: check each body against its declared return type.
        functions.forEach { decl -> typecheckFunction(decl, context) }
    }

    /** Adds the built-in `Nat::add` and `Nat::rec` functions to [context]. */
    private fun registerBuiltins(context: TypeContext) {
        val nat = TypeNat()

        // fn(Nat) -> (fn(Nat) -> Nat). A fresh instance is built per use so that no
        // AST node is shared between two signatures.
        fun curriedNatOp() = TypeFun(typeList(nat), TypeFun(typeList(nat), nat))

        // Nat::add is curried: Nat::add(a)(b).
        context.addFunction("Nat::add", curriedNatOp())
        // Nat::rec(n, zero, step) with step : fn(Nat) -> (fn(Nat) -> Nat).
        context.addFunction("Nat::rec", TypeFun(typeList(nat, nat, curriedNatOp()), nat))
    }

    /** Returns a fresh BNFC [ListType] containing [types] in order. */
    private fun typeList(vararg types: Type): ListType =
        ListType().apply { types.forEach { add(it) } }

    /**
     * Narrows [param] to [AParamDecl], the only parameter form this checker understands.
     *
     * @throws TypeCheckException "Unknown parameter declaration" for any other form.
     */
    private fun asParamDecl(param: ParamDecl): AParamDecl =
        param as? AParamDecl ?: throw TypeCheckException("Unknown parameter declaration")

    /**
     * Returns the explicit return type of [decl].
     *
     * @throws TypeCheckException "Function must have a return type" when there is none.
     */
    private fun returnTypeOf(decl: DeclFun): Type =
        (decl.returntype_ as? SomeReturnType)?.type_
            ?: throw TypeCheckException("Function must have a return type")

    /**
     * Builds the declared type `fn(T1, ..., Tn) -> R` of [decl].
     *
     * Parameters are validated before the return type, so a malformed parameter is
     * reported first.
     */
    private fun functionTypeOf(decl: DeclFun): TypeFun {
        val paramTypes = ListType()
        decl.listparamdecl_.forEach { paramTypes.add(asParamDecl(it).type_) }
        return TypeFun(paramTypes, returnTypeOf(decl))
    }

    /**
     * Checks the body of [decl] in a child scope of [context].
     *
     * That scope binds the function's own name (so the body may recurse) and, on top
     * of it, the function's parameters.
     *
     * @throws TypeCheckException `ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION` if the body's
     *   type differs from the declared return type.
     */
    private fun typecheckFunction(decl: DeclFun, context: TypeContext) {
        val functionType = functionTypeOf(decl)
        val localContext = context.createChildContext()
        // Bind the function name before the parameters, so that a parameter with the
        // same name shadows the function, as lexical scoping requires.
        localContext.addVariable(decl.stellaident_, functionType)
        decl.listparamdecl_.forEach { param ->
            val aParam = asParamDecl(param)
            localContext.addVariable(aParam.stellaident_, aParam.type_)
        }
        val actualReturnType = typecheckExpr(decl.expr_, localContext)
        if (actualReturnType != functionType.type_) {
            throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
        }
    }

    /** Returns the type of [expr] in [context]. */
    private fun typecheckExpr(expr: Expr, context: TypeContext): Type {
        return expr.accept(ExpressionTypeChecker(context), Unit)
    }

    /**
     * Computes the type of an expression in [context].
     *
     * It throws [TypeCheckException] on the first error it finds. Nested scopes (lambda
     * bodies and `let` bodies) are checked by a fresh checker over a child context.
     */
    private class ExpressionTypeChecker(private val context: TypeContext) : Expr.Visitor<Type, Unit> {

        /** Type of [expr] in this checker's context. */
        private fun typeOf(expr: Expr): Type = expr.accept(this, Unit)

        /** Fails with `ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION` unless [type] is `Nat`. */
        private fun expectNat(type: Type) {
            if (type !is TypeNat) throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
        }

        /**
         * Rejects an expression form this checker does not implement.
         *
         * It throws a [TypeCheckException] rather than `TODO()`'s `NotImplementedError`,
         * so that callers catching `Exception` see it.
         */
        private fun unsupported(p: Expr): Nothing =
            throw TypeCheckException("Unsupported expression: ${p.javaClass.simpleName}")

        // ---- Variables and constants ----

        override fun visit(p: Var, arg: Unit): Type =
            context.getVariableType(p.stellaident_) ?: throw TypeCheckException("ERROR_UNDEFINED_VARIABLE")

        override fun visit(p: ConstTrue, arg: Unit): Type = TypeBool()
        override fun visit(p: ConstFalse, arg: Unit): Type = TypeBool()
        override fun visit(p: ConstInt, arg: Unit): Type = TypeNat()
        override fun visit(p: ConstUnit, arg: Unit): Type = TypeUnit()

        // ---- Natural numbers ----

        override fun visit(p: Succ, arg: Unit): Type {
            expectNat(typeOf(p.expr_))
            return TypeNat()
        }

        // Nat::pred(n) : Nat, checked exactly like succ.
        override fun visit(p: Pred, arg: Unit): Type {
            expectNat(typeOf(p.expr_))
            return TypeNat()
        }

        override fun visit(p: IsZero, arg: Unit): Type {
            expectNat(typeOf(p.expr_))
            return TypeBool()
        }

        override fun visit(p: Add, arg: Unit): Type {
            // Both operands are typed before either is checked, so an error inside the
            // right operand takes precedence over a non-Nat left operand.
            val leftType = typeOf(p.expr_1)
            val rightType = typeOf(p.expr_2)
            if (leftType !is TypeNat || rightType !is TypeNat) {
                throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION: Addition requires Nat operands")
            }
            return TypeNat()
        }

        override fun visit(p: NatRec, arg: Unit): Type {
            if (typeOf(p.expr_1) !is TypeNat) {
                throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
            }
            val zeroType = typeOf(p.expr_2)
            val stepType = typeOf(p.expr_3) as? TypeFun
                ?: throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
            if (stepType.listtype_.size != 1 || stepType.listtype_[0] !is TypeNat) {
                throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_PARAMETER")
            }
            // The step must have type fn(Nat) -> (fn(T) -> T), where T is the zero case's type.
            val inner = stepType.type_
            if (inner !is TypeFun || inner.listtype_.size != 1 ||
                inner.listtype_[0] != zeroType || inner.type_ != zeroType
            ) {
                throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
            }
            return zeroType
        }

        // ---- Control flow, functions and bindings ----

        override fun visit(p: If, arg: Unit): Type {
            if (typeOf(p.expr_1) !is TypeBool) {
                throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
            }
            val thenType = typeOf(p.expr_2)
            val elseType = typeOf(p.expr_3)
            if (thenType != elseType) {
                throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
            }
            return thenType
        }

        override fun visit(p: Abstraction, arg: Unit): Type {
            if (p.listparamdecl_.size != 1) {
                throw TypeCheckException("First-class functions must have exactly one parameter")
            }
            val paramDecl = p.listparamdecl_[0] as? AParamDecl
                ?: throw TypeCheckException("Invalid parameter declaration")
            val bodyContext = context.createChildContext()
            bodyContext.addVariable(paramDecl.stellaident_, paramDecl.type_)
            val bodyType = p.expr_.accept(ExpressionTypeChecker(bodyContext), arg)
            return TypeFun(typeList(paramDecl.type_), bodyType)
        }

        override fun visit(p: Application, arg: Unit): Type {
            val funcType = typeOf(p.expr_) as? TypeFun
                ?: throw TypeCheckException("ERROR_NOT_A_FUNCTION")
            val paramTypes = funcType.listtype_
            // Arity must match exactly. Checking only the first parameter would accept
            // partial applications such as f(1) for f : fn(Nat, Nat) -> Nat, and a
            // nullary function type has no first parameter at all.
            if (p.listexpr_.size != paramTypes.size) {
                throw TypeCheckException("ERROR_INCORRECT_NUMBER_OF_ARGUMENTS")
            }
            p.listexpr_.zip(paramTypes).forEach { (argExpr, paramType) ->
                if (typeOf(argExpr) != paramType) {
                    throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_PARAMETER")
                }
            }
            return funcType.type_
        }

        override fun visit(p: Let, arg: Unit): Type {
            // Bound expressions are typed in the enclosing scope. Their names become
            // visible only in the body, through a child scope.
            val letContext = context.createChildContext()
            p.listpatternbinding_.forEach { binding ->
                val patternBinding = binding as? APatternBinding
                    ?: throw TypeCheckException("Invalid pattern binding")
                val bindingType = typeOf(patternBinding.expr_)
                val pattern = patternBinding.pattern_ as? PatternVar
                    ?: throw TypeCheckException("Only variable patterns are supported")
                letContext.addVariable(pattern.stellaident_, bindingType)
            }
            return p.expr_.accept(ExpressionTypeChecker(letContext), arg)
        }

        // ---- Pairs ----

        override fun visit(p: Tuple, arg: Unit): Type {
            // Only pairs are supported.
            if (p.listexpr_.size != 2) {
                throw TypeCheckException("Pairs must have exactly 2 components")
            }
            val firstType = typeOf(p.listexpr_[0])
            val secondType = typeOf(p.listexpr_[1])
            return TypeTuple(typeList(firstType, secondType))
        }

        override fun visit(p: DotTuple, arg: Unit): Type {
            val tupleType = typeOf(p.expr_) as? TypeTuple
                ?: throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION: Expected a tuple type")
            if (tupleType.listtype_.size != 2) {
                throw TypeCheckException("Expected a pair (tuple with 2 components)")
            }
            // Stella tuple indices are 1-based.
            return when (p.integer_) {
                1 -> tupleType.listtype_[0]
                2 -> tupleType.listtype_[1]
                else -> throw TypeCheckException(
                    "ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION: Invalid tuple index, must be 1 or 2",
                )
            }
        }

        // ---- Not implemented yet: rejected with a TypeCheckException ----

        override fun visit(p: Sequence, arg: Unit): Type = unsupported(p)
        override fun visit(p: Assign, arg: Unit): Type = unsupported(p)
        override fun visit(p: LetRec, arg: Unit): Type = unsupported(p)
        override fun visit(p: TypeAbstraction, arg: Unit): Type = unsupported(p)
        override fun visit(p: LessThan, arg: Unit): Type = unsupported(p)
        override fun visit(p: LessThanOrEqual, arg: Unit): Type = unsupported(p)
        override fun visit(p: GreaterThan, arg: Unit): Type = unsupported(p)
        override fun visit(p: GreaterThanOrEqual, arg: Unit): Type = unsupported(p)
        override fun visit(p: Equal, arg: Unit): Type = unsupported(p)
        override fun visit(p: NotEqual, arg: Unit): Type = unsupported(p)
        override fun visit(p: TypeAsc, arg: Unit): Type = unsupported(p)
        override fun visit(p: TypeCast, arg: Unit): Type = unsupported(p)
        override fun visit(p: Variant, arg: Unit): Type = unsupported(p)
        override fun visit(p: Match, arg: Unit): Type = unsupported(p)
        override fun visit(p: List, arg: Unit): Type = unsupported(p)
        override fun visit(p: Subtract, arg: Unit): Type = unsupported(p)
        override fun visit(p: LogicOr, arg: Unit): Type = unsupported(p)
        override fun visit(p: Multiply, arg: Unit): Type = unsupported(p)
        override fun visit(p: Divide, arg: Unit): Type = unsupported(p)
        override fun visit(p: LogicAnd, arg: Unit): Type = unsupported(p)
        override fun visit(p: Ref, arg: Unit): Type = unsupported(p)
        override fun visit(p: Deref, arg: Unit): Type = unsupported(p)
        override fun visit(p: TypeApplication, arg: Unit): Type = unsupported(p)
        override fun visit(p: DotRecord, arg: Unit): Type = unsupported(p)
        override fun visit(p: Record, arg: Unit): Type = unsupported(p)
        override fun visit(p: ConsList, arg: Unit): Type = unsupported(p)
        override fun visit(p: Head, arg: Unit): Type = unsupported(p)
        override fun visit(p: IsEmpty, arg: Unit): Type = unsupported(p)
        override fun visit(p: Tail, arg: Unit): Type = unsupported(p)
        override fun visit(p: Panic, arg: Unit): Type = unsupported(p)
        override fun visit(p: Throw, arg: Unit): Type = unsupported(p)
        override fun visit(p: TryCatch, arg: Unit): Type = unsupported(p)
        override fun visit(p: TryWith, arg: Unit): Type = unsupported(p)
        override fun visit(p: TryCastAs, arg: Unit): Type = unsupported(p)
        override fun visit(p: Inl, arg: Unit): Type = unsupported(p)
        override fun visit(p: Inr, arg: Unit): Type = unsupported(p)
        override fun visit(p: LogicNot, arg: Unit): Type = unsupported(p)
        override fun visit(p: Fix, arg: Unit): Type = unsupported(p)
        override fun visit(p: Fold, arg: Unit): Type = unsupported(p)
        override fun visit(p: Unfold, arg: Unit): Type = unsupported(p)
        override fun visit(p: ConstMemory, arg: Unit): Type = unsupported(p)
    }

    /**
     * A lexical scope mapping names to types.
     *
     * A child created with [createChildContext] sees every binding of its ancestors.
     * It may shadow them without affecting the parent, because lookups fall back to
     * the parent only when a name is not bound locally.
     */
    class TypeContext private constructor(private val parent: TypeContext?) {
        private val variables = mutableMapOf<String, Type>()
        private val functions = mutableMapOf<String, TypeFun>()

        /** Creates an empty root scope. */
        constructor() : this(null)

        /** Returns a new scope nested inside this one. */
        fun createChildContext(): TypeContext = TypeContext(this)

        /** Binds [name] to [type] in this scope, replacing any local binding. */
        fun addVariable(name: String, type: Type) {
            variables[name] = type
        }

        /** Looks up [name] in this scope, then in enclosing scopes; `null` if unbound. */
        fun getVariableType(name: String): Type? =
            variables[name] ?: parent?.getVariableType(name)

        /**
         * Registers a function. It is also bound as an ordinary variable, so that `Var`
         * references to it resolve through [getVariableType].
         */
        fun addFunction(name: String, type: TypeFun) {
            functions[name] = type
            variables[name] = type
        }

        /** Looks up a function registered with [addFunction] here or in an enclosing scope. */
        fun getFunctionType(name: String): TypeFun? =
            functions[name] ?: parent?.getFunctionType(name)
    }

    /** Raised for every type error and every unsupported construct. */
    class TypeCheckException(message: String) : Exception(message)
}
Add Comment
Please, Sign In to add comment