# Tests for svm_analysis(), svm_predict() and svm_plot() on the iris data.
library(testthat)
library(kernlab)

# Split the iris dataset into training (70%) and test (30%) sets ----
set.seed(123) # Set seed for reproducibility
n_rows <- nrow(iris)
# floor() makes the (fractional) 70% size an explicit whole number of rows;
# for iris this is 105, the same sample size the split has always used.
train_index <- sample(seq_len(n_rows), size = floor(0.7 * n_rows))
train_data <- iris[train_index, ]
test_data <- iris[-train_index, ]

# Model formulas shared by the tests below ----
# Two predictors: the feature plane can be plotted directly.
formula_2d <- Species ~ Sepal.Length + Sepal.Width
# All four predictors: per the test name below, svm_plot() is expected to
# reduce these (PCA) before plotting -- NOTE(review): confirm in svm_plot().
formula_all <- Species ~ Sepal.Length + Sepal.Width +
  Petal.Length + Petal.Width

test_that("svm_analysis returns a valid model", {
  # Train on the training set only
  model <- svm_analysis(train_data, formula_2d)

  # The result must be a kernlab `ksvm` model. ksvm is an S4 class (never a
  # plain list), so its contents are read through kernlab's accessors.
  expect_s4_class(model, "ksvm")

  # A trained model has at least one support vector and its coefficients
  expect_gt(length(coef(model)), 0)
  expect_gt(nSV(model), 0)
})

test_that("svm_predict returns correct predictions", {
  # Train on the training set, predict on the held-out test set
  model <- svm_analysis(train_data, formula_2d)
  predictions <- svm_predict(model, test_data)

  # Class predictions must be a factor. typeof() of a factor is "integer",
  # so a type check alone would also accept a bare integer vector.
  expect_s3_class(predictions, "factor")
  # Exactly one prediction per test row
  expect_length(predictions, nrow(test_data))
  # Every predicted label must be a species present in the data (also
  # rejects NA predictions)
  expect_true(
    all(as.character(predictions) %in% levels(train_data$Species))
  )
})

test_that("svm_plot runs without error", {
  model <- svm_analysis(train_data, formula_2d)

  # regexp = NA asserts that no error is raised
  expect_error(svm_plot(model, train_data, formula_2d), regexp = NA)
})

test_that("svm_plot handles PCA for high-dimensional data", {
  # More than two predictors: cannot be drawn in a 2-D plane directly
  model <- svm_analysis(train_data, formula_all)

  # The plotting function must run without error on the reduced data
  expect_error(svm_plot(model, train_data, formula_all), regexp = NA)
})