EBobkunov

compiler part

Feb 1st, 2025
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Kotlin 16.77 KB | None | 0 0
  1. package org.stella.typecheck
  2.  
  3. import org.syntax.stella.Absyn.*
  4. import org.syntax.stella.Absyn.List
  5.  
  6. object TypeCheck {
  7.     @Throws(Exception::class)
  8.     fun typecheckProgram(program: Program) {
  9.         val context = TypeContext()
  10.  
  11.         // Add built-in Nat functions to context
  12.         val natType = TypeNat()
  13.         val natFuncType = TypeFun(ListType().apply { add(natType) }, natType)
  14.         val natBinOpType = TypeFun(ListType().apply { add(natType) }, natFuncType)
  15.  
  16.         // Add Nat::add
  17.         context.addFunction("Nat::add", natBinOpType)
  18.  
  19.         // Add Nat::rec
  20.         val recReturnFunc = TypeFun(ListType().apply { add(natType) }, natType)
  21.         val recFunc = TypeFun(ListType().apply {
  22.             add(natType) // n
  23.             add(natType) // zero case
  24.             add(TypeFun(ListType().apply { add(natType) }, // i
  25.                 TypeFun(ListType().apply { add(natType) }, natType))) // r
  26.         }, natType)
  27.         context.addFunction("Nat::rec", recFunc)
  28.  
  29.         when (program) {
  30.             is AProgram -> {
  31.                 program.listdecl_.forEach { decl ->
  32.                     when (decl) {
  33.                         is DeclFun -> {
  34.                             val paramTypesListType = ListType().apply {
  35.                                 addAll(decl.listparamdecl_.map {
  36.                                     when (it) {
  37.                                         is AParamDecl -> it.type_
  38.                                         else -> throw TypeCheckException("Unknown parameter declaration")
  39.                                     }
  40.                                 })
  41.                             }
  42.                             val returnType = when (decl.returntype_) {
  43.                                 is SomeReturnType -> decl.returntype_.type_
  44.                                 else -> throw TypeCheckException("Function must have a return type")
  45.                             }
  46.                             context.addFunction(decl.stellaident_, TypeFun(paramTypesListType, returnType))
  47.                         }
  48.                     }
  49.                 }
  50.  
  51.                 program.listdecl_.forEach { decl ->
  52.                     when (decl) {
  53.                         is DeclFun -> typecheckFunction(decl, context)
  54.                     }
  55.                 }
  56.             }
  57.             else -> throw TypeCheckException("Invalid program structure")
  58.         }
  59.     }
  60.  
  61.     private fun typecheckFunction(decl: DeclFun, context: TypeContext) {
  62.         val localContext = context.createChildContext()
  63.  
  64.         // Add parameters to the context
  65.         decl.listparamdecl_.forEach { param ->
  66.             when (param) {
  67.                 is AParamDecl -> {
  68.                     localContext.addVariable(param.stellaident_, param.type_)
  69.                 }
  70.  
  71.                 else -> throw TypeCheckException("Unknown parameter declaration")
  72.             }
  73.         }
  74.  
  75.         // Get the expected return type
  76.         val expectedReturnType = when (decl.returntype_) {
  77.             is SomeReturnType -> decl.returntype_.type_
  78.             else -> throw TypeCheckException("Function must have a return type")
  79.         }
  80.  
  81.         // Create function type and add it to the context
  82.         val paramTypes = ListType().apply {
  83.             addAll(decl.listparamdecl_.map {
  84.                 (it as AParamDecl).type_
  85.             })
  86.         }
  87.         val functionType = TypeFun(paramTypes, expectedReturnType)
  88.         localContext.addVariable(decl.stellaident_, functionType) // Add function to local scope
  89.  
  90.         val actualReturnType = typecheckExpr(decl.expr_, localContext)
  91.         if (actualReturnType != expectedReturnType) {
  92.             throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  93.         }
  94.     }
  95.  
  96.     private fun typecheckExpr(expr: Expr, context: TypeContext): Type {
  97.         return expr.accept(ExpressionTypeChecker(context), Unit)
  98.     }
  99.  
  100.     private class ExpressionTypeChecker(private val context: TypeContext) : Expr.Visitor<Type, Unit> {
  101.         override fun visit(p: Var, arg: Unit): Type {
  102.             return context.getVariableType(p.stellaident_) ?: throw TypeCheckException("ERROR_UNDEFINED_VARIABLE")
  103.         }
  104.  
  105.         override fun visit(p: ConstTrue, arg: Unit): Type = TypeBool()
  106.         override fun visit(p: ConstFalse, arg: Unit): Type = TypeBool()
  107.         override fun visit(p: ConstInt, arg: Unit): Type = TypeNat()
  108.  
  109.         override fun visit(p: Succ, arg: Unit): Type {
  110.             val argType = p.expr_.accept(this, arg)
  111.             if (argType !is TypeNat) {
  112.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  113.             }
  114.             return TypeNat()
  115.         }
  116.  
  117.         override fun visit(p: IsZero, arg: Unit): Type {
  118.             val argType = p.expr_.accept(this, arg)
  119.             if (argType !is TypeNat) {
  120.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  121.             }
  122.             return TypeBool()
  123.         }
  124.  
  125.         override fun visit(p: If, arg: Unit): Type {
  126.             val condType = p.expr_1.accept(this, arg)
  127.             if (condType !is TypeBool) {
  128.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  129.             }
  130.  
  131.             val thenType = p.expr_2.accept(this, arg)
  132.             val elseType = p.expr_3.accept(this, arg)
  133.  
  134.             if (thenType != elseType) {
  135.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  136.             }
  137.             return thenType
  138.         }
  139.  
  140.         override fun visit(p: Abstraction, arg: Unit): Type {
  141.             if (p.listparamdecl_.size != 1) {
  142.                 throw TypeCheckException("First-class functions must have exactly one parameter")
  143.             }
  144.  
  145.             val localContext = context.createChildContext()
  146.             val paramDecl =
  147.                 p.listparamdecl_[0] as? AParamDecl ?: throw TypeCheckException("Invalid parameter declaration")
  148.  
  149.             localContext.addVariable(paramDecl.stellaident_, paramDecl.type_)
  150.             val bodyType = p.expr_.accept(ExpressionTypeChecker(localContext), arg)
  151.  
  152.             val paramTypeList = ListType().apply {
  153.                 add(paramDecl.type_)
  154.             }
  155.             return TypeFun(paramTypeList, bodyType)
  156.         }
  157.  
  158.         override fun visit(p: Application, arg: Unit): Type {
  159.             val funcType = p.expr_.accept(this, arg)
  160.             if (funcType !is TypeFun) {
  161.                 throw TypeCheckException("ERROR_NOT_A_FUNCTION")
  162.             }
  163.  
  164.             if (p.listexpr_.size != 1) {
  165.                 throw TypeCheckException("Function application must have exactly one argument")
  166.             }
  167.  
  168.             val argType = p.listexpr_[0].accept(this, arg)
  169.             if (argType != funcType.listtype_[0]) {
  170.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_PARAMETER")
  171.             }
  172.  
  173.             return funcType.type_
  174.         }
  175.  
  176.         override fun visit(p: NatRec, arg: Unit): Type {
  177.  
  178.             val nType = p.expr_1.accept(this, arg)
  179.             if (nType !is TypeNat) {
  180.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  181.             }
  182.  
  183.  
  184.             val zType = p.expr_2.accept(this, arg)
  185.  
  186.  
  187.             val fType = p.expr_3.accept(this, arg)
  188.             if (fType !is TypeFun) {
  189.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  190.             }
  191.  
  192.  
  193.             if (fType.listtype_.size != 1 || fType.listtype_[0] !is TypeNat) {
  194.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_PARAMETER")
  195.             }
  196.  
  197.             val innerFuncType = fType.type_
  198.             if (innerFuncType !is TypeFun || innerFuncType.listtype_.size != 1 || innerFuncType.listtype_[0] != zType || innerFuncType.type_ != zType) {
  199.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION")
  200.             }
  201.  
  202.             return zType
  203.         }
  204.  
  205.         override fun visit(p: Sequence?, arg: Unit?): Type {
  206.             TODO("Not yet implemented")
  207.         }
  208.  
  209.         override fun visit(p: Assign?, arg: Unit?): Type {
  210.             TODO("Not yet implemented")
  211.         }
  212.  
  213.         override fun visit(p: Let, arg: Unit): Type {
  214.             // Create a new context for let bindings
  215.             val letContext = context.createChildContext()
  216.  
  217.             // Process all pattern bindings (assuming they're all variables)
  218.             p.listpatternbinding_.forEach { binding ->
  219.                 when (binding) {
  220.                     is APatternBinding -> {
  221.                         // Type check the expression being bound
  222.                         val bindingType = binding.expr_.accept(this, arg)
  223.  
  224.                         // Get variable name from pattern (assuming it's always PatternVar)
  225.                         when (val pattern = binding.pattern_) {
  226.                             is PatternVar -> {
  227.                                 // Add variable binding to context
  228.                                 letContext.addVariable(pattern.stellaident_, bindingType)
  229.                             }
  230.                             else -> throw TypeCheckException("Only variable patterns are supported")
  231.                         }
  232.                     }
  233.                     else -> throw TypeCheckException("Invalid pattern binding")
  234.                 }
  235.             }
  236.  
  237.             // Type check the body expression in the new context with bindings
  238.             return p.expr_.accept(ExpressionTypeChecker(letContext), arg)
  239.         }
  240.  
  241.         override fun visit(p: LetRec?, arg: Unit?): Type {
  242.             TODO("Not yet implemented")
  243.         }
  244.  
  245.         override fun visit(p: TypeAbstraction?, arg: Unit?): Type {
  246.             TODO("Not yet implemented")
  247.         }
  248.  
  249.         override fun visit(p: LessThan?, arg: Unit?): Type {
  250.             TODO("Not yet implemented")
  251.         }
  252.  
  253.         override fun visit(p: LessThanOrEqual?, arg: Unit?): Type {
  254.             TODO("Not yet implemented")
  255.         }
  256.  
  257.         override fun visit(p: GreaterThan?, arg: Unit?): Type {
  258.             TODO("Not yet implemented")
  259.         }
  260.  
  261.         override fun visit(p: GreaterThanOrEqual?, arg: Unit?): Type {
  262.             TODO("Not yet implemented")
  263.         }
  264.  
  265.         override fun visit(p: Equal?, arg: Unit?): Type {
  266.             TODO("Not yet implemented")
  267.         }
  268.  
  269.         override fun visit(p: NotEqual?, arg: Unit?): Type {
  270.             TODO("Not yet implemented")
  271.         }
  272.  
  273.         override fun visit(p: TypeAsc?, arg: Unit?): Type {
  274.             TODO("Not yet implemented")
  275.         }
  276.  
  277.         override fun visit(p: TypeCast?, arg: Unit?): Type {
  278.             TODO("Not yet implemented")
  279.         }
  280.  
  281.         override fun visit(p: Variant?, arg: Unit?): Type {
  282.             TODO("Not yet implemented")
  283.         }
  284.  
  285.         override fun visit(p: Match?, arg: Unit?): Type {
  286.             TODO("Not yet implemented")
  287.         }
  288.  
  289.         override fun visit(p: List?, arg: Unit?): Type {
  290.             TODO("Not yet implemented")
  291.         }
  292.  
  293.         override fun visit(p: Add, arg: Unit): Type {
  294.             val type1 = p.expr_1.accept(this, arg)
  295.             val type2 = p.expr_2.accept(this, arg)
  296.  
  297.             if (type1 !is TypeNat || type2 !is TypeNat) {
  298.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION: Addition requires Nat operands")
  299.             }
  300.  
  301.             return TypeNat()
  302.         }
  303.  
  304.         override fun visit(p: Subtract?, arg: Unit?): Type {
  305.             TODO("Not yet implemented")
  306.         }
  307.  
  308.         override fun visit(p: LogicOr?, arg: Unit?): Type {
  309.             TODO("Not yet implemented")
  310.         }
  311.  
  312.         override fun visit(p: Multiply?, arg: Unit?): Type {
  313.             TODO("Not yet implemented")
  314.         }
  315.  
  316.         override fun visit(p: Divide?, arg: Unit?): Type {
  317.             TODO("Not yet implemented")
  318.         }
  319.  
  320.         override fun visit(p: LogicAnd?, arg: Unit?): Type {
  321.             TODO("Not yet implemented")
  322.         }
  323.  
  324.         override fun visit(p: Ref?, arg: Unit?): Type {
  325.             TODO("Not yet implemented")
  326.         }
  327.  
  328.         override fun visit(p: Deref?, arg: Unit?): Type {
  329.             TODO("Not yet implemented")
  330.         }
  331.  
  332.         override fun visit(p: TypeApplication?, arg: Unit?): Type {
  333.             TODO("Not yet implemented")
  334.         }
  335.  
  336.         override fun visit(p: DotRecord?, arg: Unit?): Type {
  337.             TODO("Not yet implemented")
  338.         }
  339.  
  340.         override fun visit(p: DotTuple, arg: Unit): Type {
  341.             val tupleType = p.expr_.accept(this, arg)
  342.             if (tupleType !is TypeTuple) {
  343.                 throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION: Expected a tuple type")
  344.             }
  345.  
  346.             // Check if tuple has exactly 2 components
  347.             if (tupleType.listtype_.size != 2) {
  348.                 throw TypeCheckException("Expected a pair (tuple with 2 components)")
  349.             }
  350.  
  351.             // Check if the index is valid (1 or 2)
  352.             when (p.integer_) {
  353.                 1 -> return tupleType.listtype_[0]
  354.                 2 -> return tupleType.listtype_[1]
  355.                 else -> throw TypeCheckException("ERROR_UNEXPECTED_TYPE_FOR_EXPRESSION: Invalid tuple index, must be 1 or 2")
  356.             }
  357.         }
  358.  
  359.         override fun visit(p: Tuple, arg: Unit): Type {
  360.             // Ensure exactly 2 components for pairs
  361.             if (p.listexpr_.size != 2) {
  362.                 throw TypeCheckException("Pairs must have exactly 2 components")
  363.             }
  364.  
  365.             // Type check both components
  366.             val type1 = p.listexpr_[0].accept(this, arg)
  367.             val type2 = p.listexpr_[1].accept(this, arg)
  368.  
  369.             // Create tuple type with both component types
  370.             return TypeTuple(ListType().apply {
  371.                 add(type1)
  372.                 add(type2)
  373.             })
  374.         }
  375.  
  376.         override fun visit(p: Record?, arg: Unit?): Type {
  377.             TODO("Not yet implemented")
  378.         }
  379.  
  380.         override fun visit(p: ConsList?, arg: Unit?): Type {
  381.             TODO("Not yet implemented")
  382.         }
  383.  
  384.         override fun visit(p: Head?, arg: Unit?): Type {
  385.             TODO("Not yet implemented")
  386.         }
  387.  
  388.         override fun visit(p: IsEmpty?, arg: Unit?): Type {
  389.             TODO("Not yet implemented")
  390.         }
  391.  
  392.         override fun visit(p: Tail?, arg: Unit?): Type {
  393.             TODO("Not yet implemented")
  394.         }
  395.  
  396.         override fun visit(p: Panic?, arg: Unit?): Type {
  397.             TODO("Not yet implemented")
  398.         }
  399.  
  400.         override fun visit(p: Throw?, arg: Unit?): Type {
  401.             TODO("Not yet implemented")
  402.         }
  403.  
  404.         override fun visit(p: TryCatch?, arg: Unit?): Type {
  405.             TODO("Not yet implemented")
  406.         }
  407.  
  408.         override fun visit(p: TryWith?, arg: Unit?): Type {
  409.             TODO("Not yet implemented")
  410.         }
  411.  
  412.         override fun visit(p: TryCastAs?, arg: Unit?): Type {
  413.             TODO("Not yet implemented")
  414.         }
  415.  
  416.         override fun visit(p: Inl?, arg: Unit?): Type {
  417.             TODO("Not yet implemented")
  418.         }
  419.  
  420.         override fun visit(p: Inr?, arg: Unit?): Type {
  421.             TODO("Not yet implemented")
  422.         }
  423.  
  424.         override fun visit(p: LogicNot?, arg: Unit?): Type {
  425.             TODO("Not yet implemented")
  426.         }
  427.  
  428.         override fun visit(p: Pred?, arg: Unit?): Type {
  429.             TODO("Not yet implemented")
  430.         }
  431.  
  432.         override fun visit(p: Fix?, arg: Unit?): Type {
  433.             TODO("Not yet implemented")
  434.         }
  435.  
  436.         override fun visit(p: Fold?, arg: Unit?): Type {
  437.             TODO("Not yet implemented")
  438.         }
  439.  
  440.         override fun visit(p: Unfold?, arg: Unit?): Type {
  441.             TODO("Not yet implemented")
  442.         }
  443.  
  444.         override fun visit(p: ConstUnit, arg: Unit): Type = TypeUnit()
  445.  
  446.         override fun visit(p: ConstMemory?, arg: Unit?): Type {
  447.             TODO("Not yet implemented")
  448.         }
  449.     }
  450.  
  451.     class TypeContext {
  452.         private val variables = mutableMapOf<String, Type>()
  453.         private val functions = mutableMapOf<String, TypeFun>()
  454.         private val parent: TypeContext? = null
  455.  
  456.         fun createChildContext(): TypeContext {
  457.             return TypeContext().apply {
  458.                 variables.putAll(this@TypeContext.variables)
  459.                 functions.putAll(this@TypeContext.functions)
  460.             }
  461.         }
  462.  
  463.         fun addVariable(name: String, type: Type) {
  464.             variables[name] = type
  465.         }
  466.  
  467.         fun getVariableType(name: String): Type? {
  468.             return variables[name] ?: parent?.getVariableType(name)
  469.         }
  470.  
  471.         fun addFunction(name: String, type: TypeFun) {
  472.             functions[name] = type
  473.             variables[name] = type
  474.         }
  475.  
  476.         fun getFunctionType(name: String): TypeFun? {
  477.             return functions[name] ?: parent?.getFunctionType(name)
  478.         }
  479.     }
  480.  
  481.     class TypeCheckException(message: String) : Exception(message)
  482. }
  483.  
Add Comment
Please, Sign In to add comment