Advertisement
gagarin_1982

спарк линрег фанк кастом

Aug 28th, 2024
51
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.96 KB | None | 0 0
  1. def train_linear_model(df, features, target='median_house_value'):
  2.  
  3. # Разделение признаков на категориальные и числовые
  4. categorical_cols = [col for col in features if df.schema[col].dataType == StringType()]
  5. numerical_cols = [col for col in features if df.schema[col].dataType != StringType()]
  6.  
  7. # Обработка категориальных признаков
  8. if categorical_cols:
  9. indexer = StringIndexer(
  10. inputCols=categorical_cols, outputCols=[c+'_idx' for c in categorical_cols]
  11. )
  12. df = indexer.fit(df).transform(df)
  13.  
  14. encoder = OneHotEncoder(
  15. inputCols=[c+'_idx' for c in categorical_cols], outputCols=[c+'_ohe' for c in categorical_cols]
  16. )
  17. df = encoder.fit(df).transform(df)
  18.  
  19. categorical_assembler = VectorAssembler(
  20. inputCols=[c+'_ohe' for c in categorical_cols], outputCol="categorical_features"
  21. )
  22. df = categorical_assembler.transform(df)
  23.  
  24. # Обработка числовых признаков
  25. if numerical_cols:
  26. numerical_assembler = VectorAssembler(inputCols=numerical_cols, outputCol="numerical_features")
  27. df = numerical_assembler.transform(df)
  28.  
  29. standardScaler = StandardScaler(inputCol='numerical_features', outputCol="numerical_features_scaled")
  30. df = standardScaler.fit(df).transform(df)
  31.  
  32. # Объединение всех признаков
  33. all_features = []
  34. if categorical_cols:
  35. all_features.append('categorical_features')
  36. if numerical_cols:
  37. all_features.append('numerical_features_scaled')
  38.  
  39. final_assembler = VectorAssembler(inputCols=all_features, outputCol="features")
  40. df = final_assembler.transform(df)
  41.  
  42. # Разделение на обучающую и тестовую выборки
  43. train_data, test_data = df.randomSplit([.8, .2], seed=2022)
  44.  
  45. # Обучение модели
  46. lr = LinearRegression(labelCol=target, featuresCol='features')
  47. model = lr.fit(train_data)
  48.  
  49. # Прогнозирование на тестовой выборке
  50. predictions = model.transform(test_data)
  51.  
  52. # Вычисление метрик
  53. evaluator_r2 = RegressionEvaluator(labelCol=target, predictionCol="prediction", metricName="r2")
  54. evaluator_mse = RegressionEvaluator(labelCol=target, predictionCol="prediction", metricName="mse")
  55. evaluator_rmse = RegressionEvaluator(labelCol=target, predictionCol="prediction", metricName="rmse")
  56.  
  57. r2 = evaluator_r2.evaluate(predictions)
  58. mse = evaluator_mse.evaluate(predictions)
  59. rmse = evaluator_rmse.evaluate(predictions)
  60.  
  61. print(f"R²: {r2}")
  62. print(f"MSE: {mse}")
  63. print(f"RMSE: {rmse}")
  64.  
  65.  
  66. features = [
  67. 'ocean_proximity', 'longitude', 'latitude', 'housing_median_age', 'total_rooms',
  68. 'total_bedrooms', 'population', 'households', 'median_income'
  69. ]
  70.  
  71. train_linear_model(df, features=features)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement