JAX para
Arquitectos-Investigadores
Verificación de invariantes topológicos y optimización de formas estructurales. Tutorial completo en español.
Diferenciación automática
Derivadas de cualquier función sin cálculo manual
Compilación JIT
Velocidad de máquina en operaciones repetidas
Vectorización vmap
Opera sobre miles de puntos simultáneamente
GPU / TPU
Escala a hardware especializado sin cambiar código
¿Por qué JAX en una tesis de arquitectura?
JAX es una biblioteca de Python desarrollada por Google. A diferencia de NumPy estándar, JAX añade tres capacidades críticas para la investigación arquitectónica computacional:
Sin JAX
- Derivadas calculadas a mano (propensas a error)
- Métodos numéricos aproximados y lentos
- Optimización por fuerza bruta
- Código que no escala a geometrías complejas
Con JAX
- Diferenciación automática exacta
- Compilación JIT → velocidad de C++
- Gradiente de cualquier función de error
- Escala a GPU/TPU sin reescribir código
Conexión con invariantes topológicos
En la investigación paramétrica-topológica, los invariantes son propiedades que se preservan bajo transformaciones: el equilibrio de un arco, la curvatura media cero de una superficie mínima, la característica de Euler de una malla. JAX permite:
- Definir el invariante como una función matemática.
- Computar su gradiente para detectar desviaciones.
- Optimizar la geometría para restaurar la condición.
JAX no es un software de modelado arquitectónico. Es una plataforma de cómputo diferenciable. Las librerías JAX-FDM y JaxSSO construyen sobre él para ofrecer herramientas específicas de form-finding estructural.
Instalación
Requisitos previos
- Python 3.9 o superior (recomendado 3.10)
- Terminal integrada (VS Code, PyCharm, o terminal del sistema)
- GPU NVIDIA opcional — no necesaria para comenzar
Instalación de JAX base
pip install jax jaxlib
jaxlib puede requerir una versión específica en Windows. Consulta jax.readthedocs.io/en/latest/installation.html si la instalación falla.
Verificar la instalación
import jax import jax.numpy as jnp print(jax.__version__) # Output esperado: 0.4.x
Instalar JAX-FDM (para arcos y estructuras)
pip install jax-fdm
JAX-FDM implementa el método de densidades de fuerzas con diferenciación automática. Desarrollado por Pastrana (Princeton, 2026). Permite modelar y optimizar estructuras de barras y cables.
Instalar JaxSSO (para superficies)
pip install JaxSSO
JaxSSO es una librería para optimización de formas estructurales en superficies. Sus ejemplos en Google Colab son el punto de entrada más efectivo.
Arrays y operaciones básicas
JAX replica la interfaz de NumPy con jnp. Si ya conoces NumPy, la sintaxis es idéntica — la diferencia está bajo el capó (trazabilidad para diferenciación automática).
import jax.numpy as jnp # ── Crear arrays de coordenadas (útil para nodos de una malla) ── x = jnp.array([1.0, 2.0, 3.0]) y = jnp.array([4.0, 5.0, 6.0]) # ── Operaciones vectorizadas ── z = x + y print(z) # [5. 7. 9.] # ── Crear rangos de puntos (para discretizar una curva) ── t = jnp.linspace(0, 1, 100) # 100 puntos entre 0 y 1
Diferenciación automática con grad
grad es la herramienta central de JAX. Transforma cualquier función escalar en su función derivada — de forma exacta, no numérica.
from jax import grad # ── Definir una función (ej: energía potencial de un arco) ── def f(x): return 3 * x ** 2 + 2 * x + 1 # ── grad(f) devuelve la función derivada df/dx ── df_dx = grad(f) # ── Evaluar la derivada en x = 2.0 ── print(df_dx(2.0)) # → 14.0 (porque f'(x) = 6x + 2, y 6·2 + 2 = 14)
Define la función y = a * x**2 + b * x + c con a=1, b=-3, c=2. Usa grad para calcular la pendiente en x=1.5 (el punto de inflexión de la curva).
Derivadas de funciones multivariable
Para geometrías en 2D y 3D, las funciones dependen de múltiples variables. grad con el argumento argnums permite derivar respecto a cualquiera de ellas.
from jax import grad # ── Función de dos variables (ej: superficie paramétrica) ── def g(x, y): return x**2 + y**3 # ── argnums=0 → derivada respecto a x ── dg_dx = grad(g, argnums=0) print(dg_dx(3.0, 2.0)) # → 6.0 (2x evaluado en x=3) # ── argnums=1 → derivada respecto a y ── dg_dy = grad(g, argnums=1) print(dg_dy(3.0, 2.0)) # → 12.0 (3y² evaluado en y=2)
Al definir una superficie como S(u, v), las derivadas parciales ∂S/∂u y ∂S/∂v son exactamente lo que grad con argnums computa. Esto habilita el cálculo directo de la primera y segunda forma fundamental.
Compilación con jit
@jit compila la función en la primera llamada y la ejecuta a velocidad nativa en todas las siguientes. Crítico cuando se iteran miles de pasos de optimización.
from jax import jit import jax.numpy as jnp # ── El decorador @jit activa la compilación automática ── @jit def calcular_curvatura(x): # Simulación de cálculo pesado sobre una geometría for i in range(1000): x = jnp.sin(x) + jnp.cos(x) return x # Primera llamada: compila (~0.5s) resultado = calcular_curvatura(0.5) # Llamadas siguientes: velocidad de máquina (~microsegundos) resultado = calcular_curvatura(1.2)
Primera llamada — Compilación
JAX traza la función y la compila a XLA (Accelerated Linear Algebra). Puede tardar 0.5–2 segundos.
Llamadas posteriores — Ejecución nativa
El código compilado corre directamente en CPU/GPU. Aceleración típica: 10–100×.
Restricción clave
Las estructuras de control que dependen de valores del array (e.g. if x > 0) no pueden compilarse con jit. Usar jnp.where como alternativa.
Vectorización con vmap
vmap transforma una función que opera sobre un solo punto en una función que opera sobre un array completo — sin bucles explícitos.
from jax import vmap import jax.numpy as jnp # ── Función sobre un solo punto ── def altura_catenaria(x): # y = cosh(x) — ecuación de la catenaria return jnp.cosh(x) # ── 200 puntos a lo largo del arco ── puntos_x = jnp.linspace(-2, 2, 200) # ── vmap evalúa los 200 puntos simultáneamente ── alturas = vmap(altura_catenaria)(puntos_x) print(alturas.shape) # (200,)
La composición vmap(grad(f)) calcula el gradiente de f en todos los puntos de una malla simultáneamente. Esta es la base del cálculo de campos de curvatura sobre superficies discretizadas.
Experimento: Arco por el método FDM
El método de densidades de fuerzas (Force Density Method) modela estructuras de barras y cables mediante el parámetro q — la razón entre la fuerza interna y la longitud del elemento. La implementación JAX-FDM garantiza que la solución sea diferenciable respecto a cualquier parámetro de diseño.
El residual de equilibrio nodal debe ser cero. Bajo variaciones paramétricas (altura, densidades de fuerza), JAX-FDM recalcula la forma preservando este invariante.
Código completo — Modelado del arco
import jax.numpy as jnp from jax_fdm import Structure, Node, Edge from jax_fdm.equilibrium import fdm from jax import grad, jit # ════════════════════════════════════════ # 1. GEOMETRÍA DEL ARCO # Coordenadas (x, y, z) — trabajamos en 2D con z=0 # ════════════════════════════════════════ nodos = [ Node(coordinates=[0.0, 0.0, 0.0]), # apoyo izquierdo (fijo) Node(coordinates=[2.0, 1.0, 0.0]), # nodo intermedio Node(coordinates=[4.0, 2.0, 0.0]), # clave del arco Node(coordinates=[6.0, 1.0, 0.0]), # nodo intermedio Node(coordinates=[8.0, 0.0, 0.0]), # apoyo derecho (fijo) ] # Densidades de fuerza q: parámetro clave del FDM # q = fuerza_interna / longitud_elemento aristas = [ Edge(node_i=nodos[0], node_j=nodos[1], q=5.0), Edge(node_i=nodos[1], node_j=nodos[2], q=5.0), Edge(node_i=nodos[2], node_j=nodos[3], q=5.0), Edge(node_i=nodos[3], node_j=nodos[4], q=5.0), ] # Condiciones de borde: apoyos fijos nodos[0].is_fixed = True nodos[4].is_fixed = True # ════════════════════════════════════════ # 2. CARGAS EXTERNAS # Vectores [Fx, Fy, Fz] en cada nodo # ════════════════════════════════════════ cargas = jnp.array([ [0.0, 0.0, 0.0], # apoyo izq — reacción, no carga externa [0.0, -1.0, 0.0], # nodo 1 — carga vertical [0.0, -2.0, 0.0], # clave — carga máxima [0.0, -1.0, 0.0], # nodo 3 — carga vertical [0.0, 0.0, 0.0], # apoyo der — reacción ]) # ════════════════════════════════════════ # 3. CÁLCULO DE LA FORMA DE EQUILIBRIO # ════════════════════════════════════════ estructura = Structure(nodes=nodos, edges=aristas) posiciones_eq, fuerzas = fdm(estructura, cargas) # ════════════════════════════════════════ # 4. VERIFICACIÓN DEL INVARIANTE # residual ≈ 0 → equilibrio garantizado # ════════════════════════════════════════ def calcular_residual(estructura, cargas): pos, fuerzas = fdm(estructura, cargas) return jnp.sum(jnp.abs(fuerzas)) residual = calcular_residual(estructura, cargas) print(f"Residual: {residual:.2e}") # Esperado: ~1e-10 (error de precisión flotante)
Variación paramétrica preservando el invariante
# ════════════════════════════════════════ # Variar la altura de la clave y verificar # que el residual permanece ~0 en toda la familia # ════════════════════════════════════════ def residual_por_altura(param): nodos_mod = [ Node(coordinates=[0.0, 0.0, 0.0]), Node(coordinates=[2.0, 1.0, 0.0]), Node(coordinates=[4.0, param, 0.0]), # ← parámetro libre Node(coordinates=[6.0, 1.0, 0.0]), Node(coordinates=[8.0, 0.0, 0.0]), ] nodos_mod[0].is_fixed = True nodos_mod[4].is_fixed = True aristas_mod = [ Edge(node_i=nodos_mod[i], node_j=nodos_mod[i+1], q=5.0) for i in range(4) ] est = Structure(nodes=nodos_mod, edges=aristas_mod) return calcular_residual(est, cargas) # Familia de arcos con alturas 1.5 → 3.0 for h in [1.5, 2.0, 2.5, 3.0]: print(f"h={h:.1f}m → residual={residual_por_altura(h):.2e}")
Modifica el valor de q entre 1.0 y 10.0 para cada arista. Observa cómo la forma del arco cambia mientras el residual permanece ≈ 0. Esto ilustra la degeneración paramétrica de la familia de equilibrio.
Superficies mínimas con JaxSSO
Una superficie mínima tiene curvatura media H = 0 en todos sus puntos. JaxSSO permite optimizar una malla arbitraria hacia esta condición usando el gradiente diferenciable de JAX.
import jax.numpy as jnp from JaxSSO import Surface, optimize_surface, mean_curvature from jax import jit # ════════════════════════════════════════ # 1. SUPERFICIE INICIAL — catenoide paramétrica # Parámetros u ∈ [0, 2π], v ∈ [-1, 1] # ════════════════════════════════════════ def catenoide(u, v): r = 1.0 return jnp.array([ r * jnp.cosh(v) * jnp.cos(u), # X(u,v) r * jnp.cosh(v) * jnp.sin(u), # Y(u,v) v # Z(u,v) ]) # ════════════════════════════════════════ # 2. DISCRETIZACIÓN — malla 20×20 # ════════════════════════════════════════ u_vals = jnp.linspace(0, 2*jnp.pi, 20) v_vals = jnp.linspace(-1, 1, 20) U, V = jnp.meshgrid(u_vals, v_vals) puntos = jnp.array([catenoide(U, V)]) # ════════════════════════════════════════ # 3. OPTIMIZACIÓN → H = 0 # JaxSSO minimiza la desviación de la curvatura media # ════════════════════════════════════════ superficie_opt = optimize_surface( puntos, target_curvature='mean', value=0.0 ) # ════════════════════════════════════════ # 4. VERIFICAR INVARIANTE # ════════════════════════════════════════ H = mean_curvature(superficie_opt) print(f"H promedio: {jnp.mean(H):.2e}") # Esperado: ~0.0 (dentro de tolerancia numérica)
Conexión con Grasshopper
Dos rutas de integración:
- CSV export — Exportar las coordenadas de la malla optimizada como
.csve importar en Grasshopper con el componente Import CSV. - Hops — El componente Grasshopper Hops ejecuta funciones Python directamente. Permite llamar funciones JAX desde una definición paramétrica en tiempo real.
Mapa de experimentos
| Experimento | Invariante topológico | Herramienta JAX | Verificación |
|---|---|---|---|
| Exp. 3 — Cadena catenaria | Equilibrio en compresión pura | JAX-FDM + fdm() |
Residual nodal → 0 |
| Exp. 4 — Superficies de Weingarten | Relación entre curvaturas principales | JaxSSO + mean_curvature() |
H = 0 en toda la malla |
| Exp. 5 — Cubierta integrada | Múltiples invariantes simultáneos | JAX-FDM + JaxSSO combinados | Residual + H verificados en cada iteración |
Recursos adicionales
- jax.readthedocs.io — Documentación oficial de JAX
- github.com/compas-dev/jax-fdm — Repositorio JAX-FDM
- Google Colab — Ejemplos de JaxSSO (buscar notebooks en el repositorio)
Créditos y referencias
Este tutorial se apoya en dos contribuciones académicas directas. Se detallan a continuación las autorías y la naturaleza específica de su aportación.
Adiels & Williams (2025)
Emil Adiels
Chalmers University of Technology, Suecia · Department of Architecture and Civil Engineering
Chris J.K. Williams
University of Bath, Reino Unido · Department of Architecture and Civil Engineering
El marco conceptual para el cálculo de form-finding mediante diferenciación automática. Su trabajo sobre modelado de estructuras con resortes y gradientes automáticos fundamenta el enfoque metodológico con el que JAX se aplica a la verificación de invariantes estructurales — equilibrio y curvatura — en los Experimentos 3, 4 y 5.
Pastrana (2026)
Armando Rafael Pastrana Jiménez
Princeton University, Estados Unidos · School of Architecture
Desarrollo de la librería JAX-FDM — implementación del método de densidades de fuerzas (Force Density Method) con diferenciación automática sobre JAX. El Experimento del Arco (sección 08) utiliza directamente esta librería: las clases Structure, Node, Edge y la función fdm() son de su autoría.