CES Hessian Quadratic Form Equals Double Sum (`cesHessianQF_eq_sum`)

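Documentation

Shows that the CES Hessian quadratic form `cesHessianQF J ρ c v` equals the explicit double sum $\sum_i \sum_j H_{ij}\, v_i v_j$, where $H_{ij}$ is `cesHessianEntry J ρ c i j`.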

Lean 4 Proof

/-- The CES Hessian quadratic form equals the explicit double sum of the
Hessian entries against `v`:

  `cesHessianQF J ρ c v = ∑ᵢ ∑ⱼ cesHessianEntry J ρ c i j * vᵢ * vⱼ`.

After unfolding, each entry has the form `(1 - ρ) / (J² c) * (1 - J·δᵢⱼ)`.
The double sum therefore splits into:
* `(∑ⱼ vⱼ)²`, coming from the constant `1`;
* minus `J · ∑ⱼ vⱼ²`, coming from the diagonal term `J·δᵢⱼ`;
* all multiplied by the common factor `(1 - ρ) / (J² c)`.

NOTE(review): `J` is not bound in this statement. It is presumably a section
`variable (J : ℕ)`; confirm.

NOTE(review): the final `rw` closes the goal by `rfl`. That only works if the
unfolded `cesHessianQF` is `(1 - ρ) / (↑J ^ 2 * c) * ((∑ j, v j) ^ 2 - ↑J * ∑ j, v j ^ 2)`
up to `rfl`; confirm against its definition if either definition changes. -/
theorem cesHessianQF_eq_sum (ρ c : ℝ) (v : Fin J → ℝ) :
    cesHessianQF J ρ c v =
    ∑ i : Fin J, ∑ j : Fin J, cesHessianEntry J ρ c i j * v i * v j := by
  simp only [cesHessianQF, cesHessianEntry]
  -- (∑ vⱼ)² = ∑ᵢ ∑ⱼ vᵢ·vⱼ.
  -- Used right-to-left in the final `rw` to turn the double sum back into a square.
  have sq_eq : (∑ j : Fin J, v j) ^ 2 = ∑ i : Fin J, ∑ j : Fin J, v i * v j := by
    rw [sq, Fintype.sum_mul_sum]
  -- RHS: Σᵢ Σⱼ (1-ρ)/(J²c) · (1 - J·δᵢⱼ) · vᵢ · vⱼ
  --    = (1-ρ)/(J²c) · [Σᵢ Σⱼ vᵢvⱼ - J · Σᵢ vᵢ²]
  -- After rewriting with `sq_eq`, this is (1-ρ)/(J²c) · [(Σⱼ vⱼ)² - J · Σⱼ vⱼ²],
  -- which should match the unfolded LHS.
  have rhs_eq : ∑ i : Fin J, ∑ j : Fin J,
      (1 - ρ) / (↑J ^ 2 * c) * (1 - if i = j then (↑J : ℝ) else 0) * v i * v j =
    (1 - ρ) / (↑J ^ 2 * c) *
      ((∑ i : Fin J, ∑ j : Fin J, v i * v j) -
       ↑J * ∑ i : Fin J, v i ^ 2) := by
    -- Split each summand k·(1 - J·δᵢⱼ)·vᵢ·vⱼ into k·(vᵢvⱼ) + k·(-(δᵢⱼ·J·vᵢvⱼ)).
    -- Both branches of the `if` close by `ring`.
    have h1 : ∀ i j : Fin J,
        (1 - ρ) / (↑J ^ 2 * c) * (1 - if i = j then (↑J : ℝ) else 0) * v i * v j =
        (1 - ρ) / (↑J ^ 2 * c) * (v i * v j) +
        (1 - ρ) / (↑J ^ 2 * c) * (-(if i = j then ↑J * v i * v j else 0)) := by
      intro i j; split_ifs <;> ring
    -- Rewrite every summand with `h1`, then distribute the inner and outer
    -- sums over `+`.
    simp_rw [h1, Finset.sum_add_distrib]
    -- First part: Σᵢ Σⱼ k*(vᵢvⱼ) = k * Σᵢ Σⱼ vᵢvⱼ
    have p1 : ∑ i : Fin J, ∑ j : Fin J, (1 - ρ) / (↑J ^ 2 * c) * (v i * v j) =
      (1 - ρ) / (↑J ^ 2 * c) * ∑ i : Fin J, ∑ j : Fin J, v i * v j := by
      simp_rw [Finset.mul_sum]
    -- Second part: Σᵢ Σⱼ k*(-δᵢⱼ·J·vᵢvⱼ) = -k·J·Σvᵢ²
    have p2 : ∑ i : Fin J, ∑ j : Fin J,
        (1 - ρ) / (↑J ^ 2 * c) *
          -(if i = j then ↑J * v i * v j else 0) =
      -((1 - ρ) / (↑J ^ 2 * c) *
        (↑J * ∑ i : Fin J, v i ^ 2)) := by
      -- Factor the constant k out of the inner sum, then the outer sum.
      simp_rw [← Finset.mul_sum]
      -- Pull neg out of double sum
      simp_rw [Finset.sum_neg_distrib]
      -- k * -X = -(k * X)
      rw [mul_neg]
      -- Strip the outer negation, then the common factor k, leaving
      -- Σᵢ Σⱼ (if i = j then J·vᵢ·vⱼ else 0) = J · Σᵢ vᵢ².
      congr 1; congr 1
      -- Distribute J into the sum on the right, then compare summand by summand.
      rw [Finset.mul_sum]; congr 1; ext i
      -- The Kronecker-delta sum over j collapses to its j = i term: J·vᵢ·vᵢ = J·vᵢ².
      rw [Finset.sum_ite_eq Finset.univ i]; simp [sq, mul_assoc]
    rw [p1, p2]; ring
  rw [rhs_eq, ← sq_eq]

Dependency Graph

Module Section

Gradient and Hessian of CES at the symmetric point.