ARQUITECTURA COMPUTACIONAL · PYTHON

JAX para
Arquitectos-Investigadores

Verificación de invariantes topológicos y optimización de formas estructurales. Tutorial completo en español.

Diferenciación automática

Derivadas de cualquier función sin cálculo manual

Compilación JIT

Velocidad de máquina en operaciones repetidas

Vectorización vmap

Opera sobre miles de puntos simultáneamente

GPU / TPU

Escala a hardware especializado sin cambiar código

// 01

¿Por qué JAX en una tesis de arquitectura?

JAX es una biblioteca de Python desarrollada por Google. A diferencia de NumPy estándar, JAX añade tres capacidades críticas para la investigación arquitectónica computacional:

Sin JAX

  • Derivadas calculadas a mano (propensas a error)
  • Métodos numéricos aproximados y lentos
  • Optimización por fuerza bruta
  • Código que no escala a geometrías complejas

Con JAX

  • Diferenciación automática exacta
  • Compilación JIT → velocidad de C++
  • Gradiente de cualquier función de error
  • Escala a GPU/TPU sin reescribir código

Conexión con invariantes topológicos

En la investigación paramétrica-topológica, los invariantes son propiedades que se preservan bajo transformaciones: el equilibrio de un arco, la curvatura media cero de una superficie mínima, la característica de Euler de una malla. JAX permite:

  1. Definir el invariante como una función matemática.
  2. Computar su gradiente para detectar desviaciones.
  3. Optimizar la geometría para restaurar la condición.
Precisión conceptual

JAX no es un software de modelado arquitectónico. Es una plataforma de cómputo diferenciable. Las librerías JAX-FDM y JaxSSO construyen sobre él para ofrecer herramientas específicas de form-finding estructural.

// 02

Instalación

Requisitos previos

Instalación de JAX base

TERMINAL
pip install jax jaxlib
Windows

jaxlib puede requerir una versión específica en Windows. Consulta jax.readthedocs.io/en/latest/installation.html si la instalación falla.

Verificar la instalación

Python
import jax
import jax.numpy as jnp

print(jax.__version__)
# Output esperado: 0.4.x

Instalar JAX-FDM (para arcos y estructuras)

TERMINAL
pip install jax-fdm

JAX-FDM implementa el método de densidades de fuerzas con diferenciación automática. Desarrollado por Pastrana (Princeton, 2026). Permite modelar y optimizar estructuras de barras y cables.

Instalar JaxSSO (para superficies)

TERMINAL
pip install JaxSSO

JaxSSO es una librería para optimización de formas estructurales en superficies. Sus ejemplos en Google Colab son el punto de entrada más efectivo.


// 03

Arrays y operaciones básicas

JAX replica la interfaz de NumPy con jnp. Si ya conoces NumPy, la sintaxis es idéntica — la diferencia está bajo el capó (trazabilidad para diferenciación automática).

Python arrays_basicos.py
import jax.numpy as jnp

# ── Crear arrays de coordenadas (útil para nodos de una malla) ──
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# ── Operaciones vectorizadas ──
z = x + y
print(z)  # [5. 7. 9.]

# ── Crear rangos de puntos (para discretizar una curva) ──
t = jnp.linspace(0, 1, 100)  # 100 puntos entre 0 y 1
Representación de un arco como array de coordenadas
nodo[0] nodo[1] clave nodo[3] nodo[4] jnp.array([[0,110], [160,35], [240,10], [320,35], [440,110]])
// 04

Diferenciación automática con grad

grad es la herramienta central de JAX. Transforma cualquier función escalar en su función derivada — de forma exacta, no numérica.

Flujo de diferenciación automática
f(x) función grad(f) transforma df/dx (x) derivada exacta
Python diferenciacion.py
from jax import grad

# ── Definir una función (ej: energía potencial de un arco) ──
def f(x):
    return 3 * x ** 2 + 2 * x + 1

# ── grad(f) devuelve la función derivada df/dx ──
df_dx = grad(f)

# ── Evaluar la derivada en x = 2.0 ──
print(df_dx(2.0))
# → 14.0   (porque f'(x) = 6x + 2, y 6·2 + 2 = 14)
Ejercicio 1 — Pendiente de una parábola

Define la función y = a * x**2 + b * x + c con a=1, b=-3, c=2. Usa grad para calcular la pendiente en x=1.5 (el punto de inflexión de la curva).

// 05

Derivadas de funciones multivariable

Para geometrías en 2D y 3D, las funciones dependen de múltiples variables. grad con el argumento argnums permite derivar respecto a cualquiera de ellas.

Python multivariable.py
from jax import grad

# ── Función de dos variables (ej: superficie paramétrica) ──
def g(x, y):
    return x**2 + y**3

# ── argnums=0 → derivada respecto a x ──
dg_dx = grad(g, argnums=0)
print(dg_dx(3.0, 2.0))  # → 6.0   (2x evaluado en x=3)

# ── argnums=1 → derivada respecto a y ──
dg_dy = grad(g, argnums=1)
print(dg_dy(3.0, 2.0))  # → 12.0  (3y² evaluado en y=2)
Relevancia para superficies de Weingarten

Al definir una superficie como S(u, v), las derivadas parciales ∂S/∂u y ∂S/∂v son exactamente lo que grad con argnums computa. Esto habilita el cálculo directo de la primera y segunda forma fundamental.

// 06

Compilación con jit

@jit compila la función en la primera llamada y la ejecuta a velocidad nativa en todas las siguientes. Crítico cuando se iteran miles de pasos de optimización.

Python compilacion_jit.py
from jax import jit
import jax.numpy as jnp

# ── El decorador @jit activa la compilación automática ──
@jit
def calcular_curvatura(x):
    # Simulación de cálculo pesado sobre una geometría
    for i in range(1000):
        x = jnp.sin(x) + jnp.cos(x)
    return x

# Primera llamada: compila (~0.5s)
resultado = calcular_curvatura(0.5)

# Llamadas siguientes: velocidad de máquina (~microsegundos)
resultado = calcular_curvatura(1.2)
1

Primera llamada — Compilación

JAX traza la función y la compila a XLA (Accelerated Linear Algebra). Puede tardar 0.5–2 segundos.

2

Llamadas posteriores — Ejecución nativa

El código compilado corre directamente en CPU/GPU. Aceleración típica: 10–100×.

3

Restricción clave

Las estructuras de control que dependen de valores del array (e.g. if x > 0) no pueden compilarse con jit. Usar jnp.where como alternativa.

// 07

Vectorización con vmap

vmap transforma una función que opera sobre un solo punto en una función que opera sobre un array completo — sin bucles explícitos.

Python vectorizacion_vmap.py
from jax import vmap
import jax.numpy as jnp

# ── Función sobre un solo punto ──
def altura_catenaria(x):
    # y = cosh(x) — ecuación de la catenaria
    return jnp.cosh(x)

# ── 200 puntos a lo largo del arco ──
puntos_x = jnp.linspace(-2, 2, 200)

# ── vmap evalúa los 200 puntos simultáneamente ──
alturas = vmap(altura_catenaria)(puntos_x)

print(alturas.shape)  # (200,)
vmap + grad combinados

La composición vmap(grad(f)) calcula el gradiente de f en todos los puntos de una malla simultáneamente. Esta es la base del cálculo de campos de curvatura sobre superficies discretizadas.


// 08

Experimento: Arco por el método FDM

El método de densidades de fuerzas (Force Density Method) modela estructuras de barras y cables mediante el parámetro q — la razón entre la fuerza interna y la longitud del elemento. La implementación JAX-FDM garantiza que la solución sea diferenciable respecto a cualquier parámetro de diseño.

Invariante verificado

El residual de equilibrio nodal debe ser cero. Bajo variaciones paramétricas (altura, densidades de fuerza), JAX-FDM recalcula la forma preservando este invariante.

Código completo — Modelado del arco

Python arco_fdm.py
import jax.numpy as jnp
from jax_fdm import Structure, Node, Edge
from jax_fdm.equilibrium import fdm
from jax import grad, jit

# ════════════════════════════════════════
# 1. GEOMETRÍA DEL ARCO
# Coordenadas (x, y, z) — trabajamos en 2D con z=0
# ════════════════════════════════════════
nodos = [
    Node(coordinates=[0.0, 0.0, 0.0]),  # apoyo izquierdo (fijo)
    Node(coordinates=[2.0, 1.0, 0.0]),  # nodo intermedio
    Node(coordinates=[4.0, 2.0, 0.0]),  # clave del arco
    Node(coordinates=[6.0, 1.0, 0.0]),  # nodo intermedio
    Node(coordinates=[8.0, 0.0, 0.0]),  # apoyo derecho (fijo)
]

# Densidades de fuerza q: parámetro clave del FDM
# q = fuerza_interna / longitud_elemento
aristas = [
    Edge(node_i=nodos[0], node_j=nodos[1], q=5.0),
    Edge(node_i=nodos[1], node_j=nodos[2], q=5.0),
    Edge(node_i=nodos[2], node_j=nodos[3], q=5.0),
    Edge(node_i=nodos[3], node_j=nodos[4], q=5.0),
]

# Condiciones de borde: apoyos fijos
nodos[0].is_fixed = True
nodos[4].is_fixed = True

# ════════════════════════════════════════
# 2. CARGAS EXTERNAS
# Vectores [Fx, Fy, Fz] en cada nodo
# ════════════════════════════════════════
cargas = jnp.array([
    [0.0,  0.0, 0.0],  # apoyo izq — reacción, no carga externa
    [0.0, -1.0, 0.0],  # nodo 1 — carga vertical
    [0.0, -2.0, 0.0],  # clave — carga máxima
    [0.0, -1.0, 0.0],  # nodo 3 — carga vertical
    [0.0,  0.0, 0.0],  # apoyo der — reacción
])

# ════════════════════════════════════════
# 3. CÁLCULO DE LA FORMA DE EQUILIBRIO
# ════════════════════════════════════════
estructura = Structure(nodes=nodos, edges=aristas)
posiciones_eq, fuerzas = fdm(estructura, cargas)

# ════════════════════════════════════════
# 4. VERIFICACIÓN DEL INVARIANTE
# residual ≈ 0 → equilibrio garantizado
# ════════════════════════════════════════
def calcular_residual(estructura, cargas):
    pos, fuerzas = fdm(estructura, cargas)
    return jnp.sum(jnp.abs(fuerzas))

residual = calcular_residual(estructura, cargas)
print(f"Residual: {residual:.2e}")
# Esperado: ~1e-10 (error de precisión flotante)

Variación paramétrica preservando el invariante

Python variacion_altura.py
# ════════════════════════════════════════
# Variar la altura de la clave y verificar
# que el residual permanece ~0 en toda la familia
# ════════════════════════════════════════
def residual_por_altura(param):
    nodos_mod = [
        Node(coordinates=[0.0, 0.0, 0.0]),
        Node(coordinates=[2.0, 1.0, 0.0]),
        Node(coordinates=[4.0, param, 0.0]),  # ← parámetro libre
        Node(coordinates=[6.0, 1.0, 0.0]),
        Node(coordinates=[8.0, 0.0, 0.0]),
    ]
    nodos_mod[0].is_fixed = True
    nodos_mod[4].is_fixed = True

    aristas_mod = [
        Edge(node_i=nodos_mod[i], node_j=nodos_mod[i+1], q=5.0)
        for i in range(4)
    ]
    est = Structure(nodes=nodos_mod, edges=aristas_mod)
    return calcular_residual(est, cargas)

# Familia de arcos con alturas 1.5 → 3.0
for h in [1.5, 2.0, 2.5, 3.0]:
    print(f"h={h:.1f}m → residual={residual_por_altura(h):.2e}")
Ejercicio 2 — Densidades de fuerza variables

Modifica el valor de q entre 1.0 y 10.0 para cada arista. Observa cómo la forma del arco cambia mientras el residual permanece ≈ 0. Esto ilustra la degeneración paramétrica de la familia de equilibrio.

// 09

Superficies mínimas con JaxSSO

Una superficie mínima tiene curvatura media H = 0 en todos sus puntos. JaxSSO permite optimizar una malla arbitraria hacia esta condición usando el gradiente diferenciable de JAX.

Curvatura media H = (κ₁ + κ₂) / 2
κ₁ > 0 κ₂ < 0 H = (κ₁ + κ₂)/2 = 0 curvaturas principales iguales y opuestas
Python superficie_minima.py
import jax.numpy as jnp
from JaxSSO import Surface, optimize_surface, mean_curvature
from jax import jit

# ════════════════════════════════════════
# 1. SUPERFICIE INICIAL — catenoide paramétrica
# Parámetros u ∈ [0, 2π], v ∈ [-1, 1]
# ════════════════════════════════════════
def catenoide(u, v):
    r = 1.0
    return jnp.array([
        r * jnp.cosh(v) * jnp.cos(u),  # X(u,v)
        r * jnp.cosh(v) * jnp.sin(u),  # Y(u,v)
        v                                  # Z(u,v)
    ])

# ════════════════════════════════════════
# 2. DISCRETIZACIÓN — malla 20×20
# ════════════════════════════════════════
u_vals = jnp.linspace(0, 2*jnp.pi, 20)
v_vals = jnp.linspace(-1, 1, 20)
U, V = jnp.meshgrid(u_vals, v_vals)
puntos = jnp.array([catenoide(U, V)])

# ════════════════════════════════════════
# 3. OPTIMIZACIÓN → H = 0
# JaxSSO minimiza la desviación de la curvatura media
# ════════════════════════════════════════
superficie_opt = optimize_surface(
    puntos,
    target_curvature='mean',
    value=0.0
)

# ════════════════════════════════════════
# 4. VERIFICAR INVARIANTE
# ════════════════════════════════════════
H = mean_curvature(superficie_opt)
print(f"H promedio: {jnp.mean(H):.2e}")
# Esperado: ~0.0 (dentro de tolerancia numérica)

Conexión con Grasshopper

Dos rutas de integración:

// 10

Mapa de experimentos

Experimento Invariante topológico Herramienta JAX Verificación
Exp. 3 — Cadena catenaria Equilibrio en compresión pura JAX-FDM + fdm() Residual nodal → 0
Exp. 4 — Superficies de Weingarten Relación entre curvaturas principales JaxSSO + mean_curvature() H = 0 en toda la malla
Exp. 5 — Cubierta integrada Múltiples invariantes simultáneos JAX-FDM + JaxSSO combinados Residual + H verificados en cada iteración

Recursos adicionales

// 11

Créditos y referencias

Este tutorial se apoya en dos contribuciones académicas directas. Se detallan a continuación las autorías y la naturaleza específica de su aportación.

Adiels & Williams (2025)

A

Emil Adiels

Chalmers University of Technology, Suecia · Department of Architecture and Civil Engineering

W

Chris J.K. Williams

University of Bath, Reino Unido · Department of Architecture and Civil Engineering

Aportación a este tutorial

El marco conceptual para el cálculo de form-finding mediante diferenciación automática. Su trabajo sobre modelado de estructuras con resortes y gradientes automáticos fundamenta el enfoque metodológico con el que JAX se aplica a la verificación de invariantes estructurales — equilibrio y curvatura — en los Experimentos 3, 4 y 5.

Pastrana (2026)

P

Armando Rafael Pastrana Jiménez

Princeton University, Estados Unidos · School of Architecture

Aportación a este tutorial

Desarrollo de la librería JAX-FDM — implementación del método de densidades de fuerzas (Force Density Method) con diferenciación automática sobre JAX. El Experimento del Arco (sección 08) utiliza directamente esta librería: las clases Structure, Node, Edge y la función fdm() son de su autoría.