下面的代码使用 match/case 语句来确定应用 manhattanDistance 还是 eucleudianDistance 距离函数。能否使用特质（trait）或遵循 DRY 原则进一步泛化这段代码，使其更易于维护？
object general {
  println("Welcome to the Scala worksheet")       //> Welcome to the Scala worksheet

  /** Selects which metric `DistanceFunctions.getDistance` applies. */
  object DistanceOptions extends Enumeration {
    type Dist = Value
    val Manhattan, Eucleudian = Value
  }

  /** Distance metrics between the `points` of two [[DataLine]]s. */
  object DistanceFunctions {

    /** Pairs up corresponding coordinates of the two lines.
      * `zip` stops at the shorter list, so surplus coordinates of a longer line are ignored.
      */
    private def coordinatePairs(l1: (DataLine, DataLine)): List[(Double, Double)] =
      l1._1.points.zip(l1._2.points)

    /** Manhattan (L1) distance: the sum of absolute coordinate differences.
      *
      * @param l1 the two lines to compare
      * @return Σ |a_i - b_i|
      */
    def manhattanDistance(l1: (DataLine, DataLine)): Double =
      coordinatePairs(l1).map { case (a, b) => math.abs(a - b) }.sum

    /** Euclidean (L2) distance: the square root of the sum of squared coordinate differences.
      *
      * @param l1 the two lines to compare
      * @return sqrt(Σ (a_i - b_i)²)
      */
    def eucleudianDistance(l1: (DataLine, DataLine)): Double =
      // Square each difference; summing |d| + |d| (as before) gave 2·L1, not L2.
      math.sqrt(coordinatePairs(l1).map { case (a, b) => (a - b) * (a - b) }.sum)

    /** Applies the metric selected by `s` to the pair `l1`.
      *
      * @param s  which metric to use
      * @param l1 the two lines to compare
      * @return the distance under the chosen metric
      */
    def getDistance(s: DistanceOptions.Dist, l1: (DataLine, DataLine)): Double =
      // The match is the result expression; nothing may follow it, or its value is discarded.
      s match {
        case DistanceOptions.Manhattan  => manhattanDistance(l1)
        case DistanceOptions.Eucleudian => eucleudianDistance(l1)
      }
  }

  /** A labelled point whose coordinates are `points`. */
  case class DataLine(label: String, points: List[Double])

  val l = (DataLine("a", List(1, 2)), DataLine("b", List(1, 2)))
                                                  //> l  : (general.DataLine, general.DataLine) = (DataLine(a,List(1.0, 2.0)),Dat
                                                  //| aLine(b,List(1.0, 2.0)))
  DistanceFunctions.getDistance(DistanceOptions.Manhattan, l)
                                                  //> res0: Double = 0.0
  DistanceFunctions.getDistance(DistanceOptions.Eucleudian, l)
                                                  //> res1: Double = 0.0
}
更新：使用类型类（type class）改写后的版本：
object gen extends App {

  /** Selects which metric `getDistance` applies. */
  object DistanceOptions extends Enumeration {
    type Dist = Value
    val Manhattan, Eucleudian = Value
  }

  /** A family of distance metrics over pairs of `T`, yielding distances of type `A`.
    *
    * Implementations provide the individual metrics. `getDistance` dispatches on
    * [[DistanceOptions]] once, here, so no implementation has to repeat the match.
    *
    * @tparam T the type of the values being compared
    * @tparam A the type of the resulting distance
    */
  trait DistanceFunctionsType[T, A] {
    def manhattanDistance(t: (T, T)): A
    def eucleudianDistance(t: (T, T)): A

    /** Applies the metric selected by `distanceOptions` to the pair `t`. */
    def getDistance(distanceOptions: DistanceOptions.Dist, t: (T, T)): A =
      distanceOptions match {
        case DistanceOptions.Manhattan  => manhattanDistance(t)
        case DistanceOptions.Eucleudian => eucleudianDistance(t)
      }
  }

  /** [[DistanceFunctionsType]] instance for [[DataLine]]s, measured over their `points`. */
  object DistanceFunctions extends DistanceFunctionsType[DataLine, Double] {

    /** Pairs up corresponding coordinates of the two lines.
      * `zip` stops at the shorter list, so surplus coordinates of a longer line are ignored.
      */
    private def coordinatePairs(l1: (DataLine, DataLine)): List[(Double, Double)] =
      l1._1.points.zip(l1._2.points)

    /** Manhattan (L1) distance: Σ |a_i - b_i|. */
    def manhattanDistance(l1: (DataLine, DataLine)): Double =
      coordinatePairs(l1).map { case (a, b) => math.abs(a - b) }.sum

    /** Euclidean (L2) distance: sqrt(Σ (a_i - b_i)²). */
    def eucleudianDistance(l1: (DataLine, DataLine)): Double =
      // Square each difference; summing |d| + |d| (as before) gave 2·L1, not L2.
      math.sqrt(coordinatePairs(l1).map { case (a, b) => (a - b) * (a - b) }.sum)

    /** Applies the metric selected by `distanceOptions` to the pair `l1`. */
    override def getDistance(distanceOptions: DistanceOptions.Dist, l1: (DataLine, DataLine)): Double =
      super.getDistance(distanceOptions, l1)
  }

  /** A labelled point whose coordinates are `points`. */
  case class DataLine(label: String, points: List[Double])

  val l = (DataLine("a", List(1, 2)), DataLine("b", List(1, 2)))
  println(DistanceFunctions.getDistance(DistanceOptions.Manhattan, l))
  println(DistanceFunctions.getDistance(DistanceOptions.Eucleudian, l))
}
在实现这种结构时，我发现这篇指南很有帮助：http://danielwestheide.com/blog/2013/02/06/the-neophytes-guide-to-scala-part-12-type-classes.html
最佳答案
是的——例如，可以参考 Spire 的 `MetricSpace`，它允许你编写如下代码：
/** A point in n-dimensional space, given by its coordinates. */
case class DataLine(
    points: List[Double]
)
import spire.algebra._
/** Manhattan (L1) distance between two [[DataLine]]s, exposed as a Spire `MetricSpace`.
  *
  * Coordinates are paired with `zip`, so surplus coordinates of a longer line are ignored.
  * (The previous body computed sqrt(Σ 2|a_i - b_i|), which matches neither its name nor
  * the Euclidean metric.)
  */
object manhattanDistance extends MetricSpace[DataLine, Double] {

  /** @return Σ |v_i - w_i| over the paired coordinates of `v` and `w` */
  def distance(v: DataLine, w: DataLine): Double =
    v.points.zip(w.points).map { case (a, b) => math.abs(a - b) }.sum
}
这种方法可以让你避免使用枚举。如果你使用 Spire 的实现，还能获得好用的运算符，以及一种简洁的方式来测试你的实现是否满足诸如三角不等式（triangle inequality）之类的度量公理；此外，许多聪明人已经替你考虑好了性能、特化（specialization）等方面的问题。
关于「scala - 这些函数可以进一步泛化吗？」，我们在 Stack Overflow 上找到一个类似的问题：https://stackoverflow.com/questions/29348865/