• בלוג
  • עמוד 26
  • יום 17 של Advent Of Code 2023 - היום בו סקאלה עבדה נגדי

יום 17 של Advent Of Code 2023 - היום בו סקאלה עבדה נגדי

24/04/2024

אני לא יודע אם זה הקוד שלי או משהו מובנה בשפה, אבל התרגיל של יום 17 היה הכי מאתגר עד כה. בפעם הראשונה בסידרה הרגשתי שמבני הנתונים ה Immutable של סקאלה שכל כך כיף לעבוד איתם לא מצליחים לספק את הביצועים הנדרשים ועברתי להשתמש ב Mutable Collections. בואו נראה את התרגיל והפיתרון וכמו תמיד תרגישו בנוח להציע פיתרונות טובים יותר או יעילים יותר למקרה שפספסתי משהו.

1. האתגר - חיפוש מסלול קצר ביותר בגרף

האתגר של יום 17 הוא מימוש כמעט קלאסי של אלגוריתם מציאת המסלול הקצר ביותר בגרף. המסלול שלנו בנוי בצורת גריד של מספרים לדוגמה:

2413432311323
3215453535623
3255245654254
3446585845452
4546657867536
1438598798454
4457876987766
3637877979653
4654967986887
4564679986453
1224686865563
2546548887735
4322674655533

אנחנו מתחילים מהנקודה השמאלית עליונה וצריכים להגיע לפינה הימנית תחתונה של הלוח, ולמצוא את המסלול שסכום המספרים עליו הוא הנמוך ביותר (לא סופרים את ה-2 של נקודת ההתחלה, אלא אם כן נעבור בה פעם שנייה). יש גם טוויסט כי אסור שיהיו במסלול יותר מ-3 צעדים באותו כיוון.

בשביל למצוא מסלול קצר ביותר אנחנו יוצאים מנקודת ההתחלה ומסמנים את כל הנקודות אליהן אפשר להגיע, ולכל נקודה שומרים מה ה"ניקוד" של אותה נקודה. שמים את נקודת ההתחלה בצד ומכל הנקודות שנשארו מוצאים את זו עם הניקוד הנמוך ביותר וממשיכים ממנה את המסלול כדי לקבל את הנקודות אליהן אפשר להגיע מנקודה זו, ואז גם את הנקודה השנייה שמים בצד ושוב ומוצאים את הנקודה עם הניקוד הכי נמוך וממשיכים ממנה את המסלול. כך ממשיכים עד שמגיעים לסוף.

הניקוד בתרגיל הזה מורכב מסכום של שני דברים: המרחק מהיעד וסכום המספרים שהוביל אותנו לנקודה. כשלוקחים כל פעם את הנקודה עם הניקוד הכי נמוך וממשיכים ממנה אנחנו מבטיחים שנתקדם בכיוון הנכון ובמסלול עם הסכום הנמוך ביותר.

2. פיתרון בסקאלה

לא משנה כמה ניסיתי לכתוב את רקורסיית הזנב בצורה יעילה, בסופו של דבר חיפוש הנקודה הבאה מתוך אוסף כל הנקודות הפוטנציאליות לקח יותר מדי זמן - ורק גדל ככל שאספתי יותר אפשרויות. מה שהייתי צריך כאן היה scala.collection.mutable.PriorityQueue שהוא חלק ממבני הנתונים ה Mutable של סקאלה.

נקודה נוספת חשובה מהפיתרון היא ההתמודדות עם המגבלה של שלושה צעדים מקסימליים לאותו כיוון. בשביל זה הגדרתי לכל צומת גם את הכיוון בו הגעתי אליו, וכך התיחסתי לאותו צומת בצורה שונה אם הגעתי אליו בצורה אנכית או אופקית.

איך עובד הקוד? קודם כל מגדירים את הצומת עם הפונקציה שמחזירה את כל השכנים:

  case class Node(pos: (Int, Int),
                  direction: Direction,
                  count: Int,
                  map: Map[(Int, Int), Int]) {

    def move(newPosition: (Int, Int), newDirection: Direction): Option[Node] =
      if (((count == 3) && newDirection == direction) || (!map.contains(newPosition))) {
        None
      } else {
        Some(Node(
          pos=newPosition,
          direction=newDirection,
          count=if ((newDirection == direction) || (direction == Direction.Start)) count + 1 else 1,
          map=map))
      }

    def movePart2(newPosition: (Int, Int), newDirection: Direction): Option[Node] =
      if (!map.contains(newPosition)) {
        return None
      }

      if ((newDirection == direction) && (count >= 10)) {
        return None
      }

      if ((newDirection != direction) && (direction != Direction.Start) && (count < 4)) {
        return None
      }

      Some(Node(
        pos = newPosition,
        direction = newDirection,
        count = if ((newDirection == direction) || (direction == Direction.Start)) count + 1 else 1,
        map = map))


    def up: Option[Node] = if (direction == Direction.Down) {
      None
    } else {
      move(newPosition=(pos._1 - 1, pos._2), newDirection=Direction.Up)
    }

    def down: Option[Node] = if (direction == Direction.Up) {
      None
    } else {
      move(newPosition=(pos._1 + 1, pos._2), newDirection=Direction.Down)
    }

    def left: Option[Node] = if (direction == Direction.Right) {
      None
    } else {
      move(newPosition=(pos._1, pos._2 - 1), newDirection=Direction.Left)
    }

    def right: Option[Node] = if (direction == Direction.Left) {
      None
    } else {
      move(newPosition=(pos._1, pos._2 + 1), newDirection=Direction.Right)
    }

    def neighbors(): List[Node] =
      val (row, column) = pos

      List(up, left, right, down)
        .filter(_.isDefined)
        .map(_.get)
  }

אחרי זה מגדירים את פונקציית המרחק מהיעד:

def h(current: (Int, Int), goal: (Int, Int)): Int =
   (goal._1 - current._1) + (goal._2 - current._2)

ובסוף פונקציית החיפוש:

  def findBestPathMutable(start: Node): Option[Int] =
    val map = start.map
    val fin = (
      map.keys.maxBy(_._1)._1,
      map.keys.maxBy(_._2)._2,
    )
    val open = mutable.PriorityQueue()(Ordering.by[(Node, Int), Int](_._2).reverse)
    val g = mutable.HashMap[Node, Int]()
    g.update(start, 0)

    val closed = new mutable.HashSet[Node]()
    open.enqueue((start, 0))

    while (open.nonEmpty) {
      val (current, cost) = open.dequeue()
      println(cost)

      if (current.pos == fin) {
        return Some(g(current))
      }
      closed.add(current)

      current
        .neighbors()
        .filter(!closed.contains(_))
        .foreach { n =>
          val fn = map(n.pos) + g(current)
          if (g.getOrElse(n, Int.MaxValue) > fn) {
            g.put(n, fn)
            open.enqueue((n, fn))
          }
      }
    }
    // no path found
    None

האלגוריתם נקרא A Star ואפשר לקרוא עליו בהרחבה בויקיפדיה כאן:

https://en.wikipedia.org/wiki/A*searchalgorithm

סך הכל התוכנית המלאה היא:


import scala.annotation.tailrec
import scala.collection.immutable.HashMap
import scala.collection.mutable
import scala.io.Source

object aoc2023day17 {

  enum Direction {
    case Up, Down, Left, Right, Start
  }


  case class Node(pos: (Int, Int),
                  direction: Direction,
                  count: Int,
                  map: Map[(Int, Int), Int]) {

    def move(newPosition: (Int, Int), newDirection: Direction): Option[Node] =
      if (((count == 3) && newDirection == direction) || (!map.contains(newPosition))) {
        None
      } else {
        Some(Node(
          pos=newPosition,
          direction=newDirection,
          count=if ((newDirection == direction) || (direction == Direction.Start)) count + 1 else 1,
          map=map))
      }

    def movePart2(newPosition: (Int, Int), newDirection: Direction): Option[Node] =
      if (!map.contains(newPosition)) {
        return None
      }

      if ((newDirection == direction) && (count >= 10)) {
        return None
      }

      if ((newDirection != direction) && (direction != Direction.Start) && (count < 4)) {
        return None
      }

      Some(Node(
        pos = newPosition,
        direction = newDirection,
        count = if ((newDirection == direction) || (direction == Direction.Start)) count + 1 else 1,
        map = map))


    def up: Option[Node] = if (direction == Direction.Down) {
      None
    } else {
      move(newPosition=(pos._1 - 1, pos._2), newDirection=Direction.Up)
    }

    def down: Option[Node] = if (direction == Direction.Up) {
      None
    } else {
      move(newPosition=(pos._1 + 1, pos._2), newDirection=Direction.Down)
    }

    def left: Option[Node] = if (direction == Direction.Right) {
      None
    } else {
      move(newPosition=(pos._1, pos._2 - 1), newDirection=Direction.Left)
    }

    def right: Option[Node] = if (direction == Direction.Left) {
      None
    } else {
      move(newPosition=(pos._1, pos._2 + 1), newDirection=Direction.Right)
    }

    def neighbors(): List[Node] =
      val (row, column) = pos

      List(up, left, right, down)
        .filter(_.isDefined)
        .map(_.get)
  }

  val demoInput: String = """2413432311323
                            |3215453535623
                            |3255245654254
                            |3446585845452
                            |4546657867536
                            |1438598798454
                            |4457876987766
                            |3637877979653
                            |4654967986887
                            |4564679986453
                            |1224686865563
                            |2546548887735
                            |4322674655533""".stripMargin


  def parseInput(input: Source): Map[(Int, Int), Int] =
    input
      .getLines()
      .zipWithIndex
      .collect {
        case (line: String, index: Int) =>
          line.toList.zipWithIndex.map((ch, column) => (index, column, ch))
      }
      .flatten
      .flatMap { case (row, column, ch) => Map((row, column) -> ch.asDigit) }
      .toMap

  def printMatrix(matrix: Map[(Int, Int), Int]): Unit =
    val maxRow = matrix.keys.maxBy(_._1)._1
    val maxColumn = matrix.keys.maxBy(_._2)._2
    0.to(maxRow).foreach { row =>
      0.to(maxColumn).foreach { col =>
        print(matrix((row, col)))
      }
      println()
    }

  def h(current: (Int, Int), goal: (Int, Int)): Int =
   (goal._1 - current._1) + (goal._2 - current._2)

  def findBestPathMutable(start: Node): Option[Int] =
    val map = start.map
    val fin = (
      map.keys.maxBy(_._1)._1,
      map.keys.maxBy(_._2)._2,
    )
    val open = mutable.PriorityQueue()(Ordering.by[(Node, Int), Int](_._2).reverse)
    val g = mutable.HashMap[Node, Int]()
    g.update(start, 0)

    val closed = new mutable.HashSet[Node]()
    open.enqueue((start, 0))

    while (open.nonEmpty) {
      val (current, cost) = open.dequeue()
      println(cost)

      if (current.pos == fin) {
        return Some(g(current))
      }
      closed.add(current)

      current
        .neighbors()
        .filter(!closed.contains(_))
        .foreach { n =>
          val fn = map(n.pos) + g(current)
          if (g.getOrElse(n, Int.MaxValue) > fn) {
            g.put(n, fn)
            open.enqueue((n, fn))
          }
      }
    }
    // no path found
    None

  @main
  def day17part1(): Unit =
//    val map = parseInput(Source.fromResource("day17.txt"))
    val map = parseInput(Source.fromString(demoInput))
    val p = Node((0, 0), Direction.Start, 0, map)
    println(findBestPathMutable(p))
}

ונ.ב. אחרון מצאתי את scastie לא מזמן אז אם בא לכם לשחק עם הקוד ולנסות לשפר אותו אתם יכולים לעשות את זה אונליין בלי שום התקנה פשוט דרך הלינק:

https://scastie.scala-lang.org/JUCoQ1GeSnyoyxnsWLnSGA