theorem cesHessianQF_eq_sum (ρ c : ℝ) (v : Fin J → ℝ) :
    cesHessianQF J ρ c v =
      ∑ i : Fin J, ∑ j : Fin J, cesHessianEntry J ρ c i j * v i * v j := by
  -- Strategy: unfold both sides, rewrite the RHS double sum as
  -- `k * (∑ᵢ ∑ⱼ vᵢvⱼ - J * ∑ᵢ vᵢ²)` with `k = (1-ρ)/(J²c)` (`rhs_eq`),
  -- then fold `∑ᵢ ∑ⱼ vᵢvⱼ` back into `(∑ⱼ vⱼ)²` (`sq_eq`); after these
  -- rewrites both sides coincide syntactically and `rw` closes the goal.
  simp only [cesHessianQF, cesHessianEntry]
  -- (∑ⱼ vⱼ)² = ∑ᵢ ∑ⱼ vᵢ·vⱼ
  have sq_eq : (∑ j : Fin J, v j) ^ 2 = ∑ i : Fin J, ∑ j : Fin J, v i * v j := by
    rw [sq, Fintype.sum_mul_sum]
  -- RHS: Σᵢ Σⱼ (1-ρ)/(J²c) · (1 - J·δᵢⱼ) · vᵢ · vⱼ
  --    = (1-ρ)/(J²c) · [Σᵢ Σⱼ vᵢvⱼ - J · Σᵢ vᵢ²]
  --    = (1-ρ)/(J²c) · [(Σvⱼ)² - J·Σvⱼ²] = LHS
  have rhs_eq : ∑ i : Fin J, ∑ j : Fin J,
      (1 - ρ) / (↑J ^ 2 * c) * (1 - if i = j then (↑J : ℝ) else 0) * v i * v j =
      (1 - ρ) / (↑J ^ 2 * c) *
        ((∑ i : Fin J, ∑ j : Fin J, v i * v j) -
          ↑J * ∑ i : Fin J, v i ^ 2) := by
    -- Split each summand into the plain product `k·vᵢvⱼ` plus a
    -- diagonal-only correction `k·(-(δᵢⱼ·J·vᵢvⱼ))`, then separate the sums.
    have h1 : ∀ i j : Fin J,
        (1 - ρ) / (↑J ^ 2 * c) * (1 - if i = j then (↑J : ℝ) else 0) * v i * v j =
          (1 - ρ) / (↑J ^ 2 * c) * (v i * v j) +
            (1 - ρ) / (↑J ^ 2 * c) * (-(if i = j then ↑J * v i * v j else 0)) := by
      intro i j; split_ifs <;> ring
    simp_rw [h1, Finset.sum_add_distrib]
    -- First part: Σᵢ Σⱼ k·(vᵢvⱼ) = k · Σᵢ Σⱼ vᵢvⱼ
    have p1 : ∑ i : Fin J, ∑ j : Fin J, (1 - ρ) / (↑J ^ 2 * c) * (v i * v j) =
        (1 - ρ) / (↑J ^ 2 * c) * ∑ i : Fin J, ∑ j : Fin J, v i * v j := by
      simp_rw [Finset.mul_sum]
    -- Second part: Σᵢ Σⱼ k·(-δᵢⱼ·J·vᵢvⱼ) = -k·J·Σᵢ vᵢ²
    have p2 : ∑ i : Fin J, ∑ j : Fin J,
        (1 - ρ) / (↑J ^ 2 * c) *
          -(if i = j then ↑J * v i * v j else 0) =
        -((1 - ρ) / (↑J ^ 2 * c) *
          (↑J * ∑ i : Fin J, v i ^ 2)) := by
      simp_rw [← Finset.mul_sum]
      -- Pull the negation out of the double sum.
      simp_rw [Finset.sum_neg_distrib]
      rw [mul_neg]
      congr 1; congr 1
      -- In each inner sum only the diagonal term `j = i` survives.
      rw [Finset.mul_sum]; congr 1; ext i
      rw [Finset.sum_ite_eq Finset.univ i]; simp [sq, mul_assoc]
    rw [p1, p2]; ring
  rw [rhs_eq, ← sq_eq]

/-! ## Gradient and Hessian of CES at the symmetric point -/